[AIRFLOW-5391] Do not re-run skipped tasks when they are cleared (#7276)
If a task is skipped by BranchPythonOperator, BaseBranchOperator or ShortCircuitOperator and the user then clears the skipped task later, it will execute. This is probably not the desired behaviour.

This commit changes that so the task will be skipped again when cleared. The new behaviour can be bypassed by running the task again with the "Ignore Task Deps" override.

(cherry picked from commit 1cdab56)
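
For context, here is a minimal sketch (not part of this commit; DAG and task names are hypothetical) of the scenario described above. The branch callable always follows `task_a`, so `task_b` ends up skipped. Before this change, clearing `task_b` would make it run; with this change the new NotPreviouslySkippedDep re-skips it unless the task is re-run with the "Ignore Task Deps" override.

```python
# Hypothetical example DAG (Airflow 1.10.x style imports), not part of this commit.
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator

with DAG(
    dag_id="branch_skip_example",          # hypothetical dag_id
    start_date=datetime(2020, 1, 1),
    schedule_interval=None,
) as dag:
    branch = BranchPythonOperator(
        task_id="branch",
        python_callable=lambda: "task_a",  # always follow task_a, so task_b is skipped
    )
    task_a = DummyOperator(task_id="task_a")
    task_b = DummyOperator(task_id="task_b")
    branch >> [task_a, task_b]

# Previously, clearing task_b after a run would re-execute it.
# After this commit, NotPreviouslySkippedDep reads the branch's XCom record
# and skips task_b again, unless the task is re-run with "Ignore Task Deps".
```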
yuqian90 authored and kaxil committed Jul 22, 2020
1 parent d23fa2f commit 17a9a5b
Showing 10 changed files with 571 additions and 56 deletions.
7 changes: 7 additions & 0 deletions UPDATING.md
@@ -25,6 +25,7 @@ assists users migrating to a new version.
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->
**Table of contents**

- [Airflow 1.10.12](#airflow-11012)
- [Airflow 1.10.11](#airflow-11011)
- [Airflow 1.10.10](#airflow-11010)
- [Airflow 1.10.9](#airflow-1109)
@@ -59,6 +60,12 @@ More tips can be found in the guide:
https://developers.google.com/style/inclusive-documentation
-->
## Airflow 1.10.12

### Clearing tasks skipped by SkipMixin will skip them

Previously, when tasks skipped by SkipMixin (such as BranchPythonOperator, BaseBranchOperator and ShortCircuitOperator) were cleared, they would execute. Since 1.10.12, when such skipped tasks are cleared,
they will be skipped again by the newly introduced NotPreviouslySkippedDep.

## Airflow 1.10.11

2 changes: 2 additions & 0 deletions airflow/models/baseoperator.py
Expand Up @@ -45,6 +45,7 @@
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
@@ -575,6 +576,7 @@ def deps(self):
NotInRetryPeriodDep(),
PrevDagrunDep(),
TriggerRuleDep(),
NotPreviouslySkippedDep(),
}

@property
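
As a quick illustration (not part of the diff), with this change applied every operator's dep set should now contain the new dep; a minimal check might look like this:

```python
# Hypothetical snippet assuming this commit is applied (Airflow 1.10.x imports).
from airflow.operators.dummy_operator import DummyOperator
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep

task = DummyOperator(task_id="any_task")  # hypothetical task_id; no DAG needed here
assert any(isinstance(dep, NotPreviouslySkippedDep) for dep in task.deps)
```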
112 changes: 83 additions & 29 deletions airflow/models/skipmixin.py
@@ -19,41 +19,44 @@

from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.db import provide_session
from airflow.utils.db import create_session, provide_session
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State

import six
from typing import Union, Iterable, Set
from typing import Set

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
XCOM_SKIPMIXIN_FOLLOWED = "followed"


class SkipMixin(LoggingMixin):
@provide_session
def skip(self, dag_run, execution_date, tasks, session=None):
def _set_state_to_skipped(self, dag_run, execution_date, tasks, session):
"""
Sets task instances to skipped from the same dag run.
:param dag_run: the DagRun for which to set the tasks to skipped
:param execution_date: execution_date
:param tasks: tasks to skip (not task_ids)
:param session: db session to use
Used internally to set state of task instances to skipped from the same dag run.
"""
if not tasks:
return

task_ids = [d.task_id for d in tasks]
now = timezone.utcnow()

if dag_run:
session.query(TaskInstance).filter(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.execution_date == dag_run.execution_date,
TaskInstance.task_id.in_(task_ids)
).update({TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now},
synchronize_session=False)
session.commit()
TaskInstance.task_id.in_(task_ids),
).update(
{
TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
synchronize_session=False,
)
else:
assert execution_date is not None, "Execution date is None and no dag run"

@@ -66,14 +69,56 @@ def skip(self, dag_run, execution_date, tasks, session=None):
ti.end_date = now
session.merge(ti)

session.commit()
@provide_session
def skip(
self, dag_run, execution_date, tasks, session=None,
):
"""
Sets task instances to skipped from the same dag run.
If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
so that NotPreviouslySkippedDep knows these tasks should be skipped when they
are cleared.
def skip_all_except(self, ti, branch_task_ids):
# type: (TaskInstance, Union[str, Iterable[str]]) -> None
:param dag_run: the DagRun for which to set the tasks to skipped
:param execution_date: execution_date
:param tasks: tasks to skip (not task_ids)
:param session: db session to use
"""
if not tasks:
return

self._set_state_to_skipped(dag_run, execution_date, tasks, session)
session.commit()

# SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
try:
task_id = self.task_id
except AttributeError:
task_id = None

if task_id is not None:
from airflow.models.xcom import XCom

XCom.set(
key=XCOM_SKIPMIXIN_KEY,
value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in tasks]},
task_id=task_id,
dag_id=dag_run.dag_id,
execution_date=dag_run.execution_date,
session=session
)

def skip_all_except(
self, ti, branch_task_ids
):
"""
This method implements the logic for a branching operator; given a single
task ID or list of task IDs to follow, this skips all other tasks
immediately downstream of this operator.
branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
newly added tasks should be skipped when they are cleared.
"""
self.log.info("Following branch %s", branch_task_ids)
if isinstance(branch_task_ids, six.string_types):
@@ -90,13 +135,22 @@ def skip_all_except(self, ti, branch_task_ids):
# is also a downstream task of the branch task, we exclude it from skipping.
branch_downstream_task_ids = set() # type: Set[str]
for b in branch_task_ids:
branch_downstream_task_ids.update(dag.
get_task(b).
get_flat_relative_ids(upstream=False))
branch_downstream_task_ids.update(
dag.get_task(b).get_flat_relative_ids(upstream=False)
)

skip_tasks = [t for t in downstream_tasks
if t.task_id not in branch_task_ids and
t.task_id not in branch_downstream_task_ids]
skip_tasks = [
t
for t in downstream_tasks
if t.task_id not in branch_task_ids
and t.task_id not in branch_downstream_task_ids
]

self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
self.skip(dag_run, ti.execution_date, skip_tasks)
with create_session() as session:
self._set_state_to_skipped(
dag_run, ti.execution_date, skip_tasks, session=session
)
ti.xcom_push(
key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
)
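
To show where `skip()` typically gets called from, here is a hedged sketch (not part of this diff) of a custom operator built on SkipMixin, mirroring the ShortCircuitOperator pattern; the class name and callable parameter are hypothetical. Because the operator has a `task_id`, `skip()` also records the skipped task IDs under `XCOM_SKIPMIXIN_KEY`, which is what lets NotPreviouslySkippedDep re-skip them after a clear.

```python
# Hypothetical operator sketch (not in this commit), Airflow 1.10.x style.
from airflow.models import BaseOperator, SkipMixin
from airflow.utils.decorators import apply_defaults


class ConditionalSkipOperator(BaseOperator, SkipMixin):
    """Skips all direct downstream tasks when the given callable returns False."""

    @apply_defaults
    def __init__(self, condition_callable, *args, **kwargs):
        super(ConditionalSkipOperator, self).__init__(*args, **kwargs)
        self.condition_callable = condition_callable

    def execute(self, context):
        if self.condition_callable(context):
            return  # condition holds: let downstream tasks run normally
        downstream_tasks = context["task"].get_flat_relatives(upstream=False)
        if downstream_tasks:
            # skip() sets the downstream task instances to SKIPPED and, since this
            # operator has a task_id, stores {"skipped": [...task_ids...]} to XCom
            # so those tasks stay skipped if they are cleared later.
            self.skip(context["dag_run"], context["ti"].execution_date, downstream_tasks)
```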
27 changes: 26 additions & 1 deletion airflow/ti_deps/dep_context.py
@@ -66,6 +68,8 @@ class DepContext(object):
:type ignore_task_deps: bool
:param ignore_ti_state: Ignore the task instance's previous failure/success
:type ignore_ti_state: bool
:param finished_tasks: A list of all the finished tasks of this run
:type finished_tasks: list[airflow.models.TaskInstance]
"""
def __init__(
self,
@@ -76,7 +78,8 @@ def __init__(
ignore_in_retry_period=False,
ignore_in_reschedule_period=False,
ignore_task_deps=False,
ignore_ti_state=False):
ignore_ti_state=False,
finished_tasks=None):
self.deps = deps or set()
self.flag_upstream_failed = flag_upstream_failed
self.ignore_all_deps = ignore_all_deps
@@ -85,6 +88,28 @@ def __init__(
self.ignore_in_reschedule_period = ignore_in_reschedule_period
self.ignore_task_deps = ignore_task_deps
self.ignore_ti_state = ignore_ti_state
self.finished_tasks = finished_tasks

def ensure_finished_tasks(self, dag, execution_date, session):
"""
This method makes sure finished_tasks is populated if it's currently None.
This is for the strange feature of running tasks without dag_run.
:param dag: The DAG for which to find finished tasks
:type dag: airflow.models.DAG
:param execution_date: The execution_date to look for
:param session: Database session to use
:return: A list of all the finished tasks of this DAG and execution_date
:rtype: list[airflow.models.TaskInstance]
"""
if self.finished_tasks is None:
self.finished_tasks = dag.get_task_instances(
start_date=execution_date,
end_date=execution_date,
state=State.finished() + [State.UPSTREAM_FAILED],
session=session,
)
return self.finished_tasks


# In order to be able to get queued a task must have one of these states
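
A small sketch (not in this diff) of how the new helper might be used; the `dag` and `execution_date` are assumed to be supplied by the caller, and the function name is hypothetical.

```python
# Hypothetical usage of DepContext.ensure_finished_tasks (not part of this commit).
from airflow.ti_deps.dep_context import DepContext
from airflow.utils.db import create_session


def finished_task_ids(dag, execution_date):
    # When finished_tasks was not pre-populated (e.g. a task run without a dag_run),
    # ensure_finished_tasks queries the DAG's task instances once and caches the
    # result on the context for subsequent deps to reuse.
    dep_context = DepContext()
    with create_session() as session:
        finished = dep_context.ensure_finished_tasks(dag, execution_date, session)
    return {ti.task_id for ti in finished}
```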
88 changes: 88 additions & 0 deletions airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -0,0 +1,88 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.ti_deps.deps.base_ti_dep import BaseTIDep


class NotPreviouslySkippedDep(BaseTIDep):
"""
Determines if any of the task's direct upstream relatives have decided this task should
be skipped.
"""

NAME = "Not Previously Skipped"
IGNORABLE = True
IS_TASK_DEP = True

def _get_dep_statuses(
self, ti, session, dep_context
): # pylint: disable=signature-differs
from airflow.models.skipmixin import (
SkipMixin,
XCOM_SKIPMIXIN_KEY,
XCOM_SKIPMIXIN_SKIPPED,
XCOM_SKIPMIXIN_FOLLOWED,
)
from airflow.utils.state import State

upstream = ti.task.get_direct_relatives(upstream=True)

finished_tasks = dep_context.ensure_finished_tasks(
ti.task.dag, ti.execution_date, session
)

finished_task_ids = {t.task_id for t in finished_tasks}

for parent in upstream:
if isinstance(parent, SkipMixin):
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue

prev_result = ti.xcom_pull(
task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY
)

if prev_result is None:
# This can happen if the parent task has not yet run.
continue

should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif (
XCOM_SKIPMIXIN_SKIPPED in prev_result
and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]
):
# Skip any tasks that are in "skipped"
should_skip = True

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
ti.set_state(State.SKIPPED, session)
yield self._failing_status(
reason="Skipping because of previous XCom result from parent task {}"
.format(parent.task_id)
)
return
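
A hedged sketch (not part of this diff) of how the new dep could be evaluated directly; it assumes `ti` is a task instance whose `task` attribute is set and whose upstream includes a SkipMixin operator that has already run, and the helper name is hypothetical.

```python
# Hypothetical helper evaluating NotPreviouslySkippedDep (not in this commit).
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.utils.db import create_session


def should_stay_skipped(ti):
    # is_met() returns False when the parent SkipMixin recorded via XCom that this
    # task should be skipped; as a side effect the dep also sets ti.state to SKIPPED.
    with create_session() as session:
        return not NotPreviouslySkippedDep().is_met(
            ti=ti, session=session, dep_context=DepContext()
        )
```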
39 changes: 39 additions & 0 deletions tests/jobs/test_scheduler_job.py
@@ -3011,3 +3011,42 @@ def test_should_mark_dummy_task_as_success(self):
self.assertIsNone(start_date)
self.assertIsNone(end_date)
self.assertIsNone(duration)


def test_task_with_upstream_skip_process_task_instances():
"""
Test if _process_task_instances puts a task instance into SKIPPED state if any of its
upstream tasks are skipped according to TriggerRuleDep.
"""
with DAG(
dag_id='test_task_with_upstream_skip_dag',
start_date=DEFAULT_DATE,
schedule_interval=None
) as dag:
dummy1 = DummyOperator(task_id='dummy1')
dummy2 = DummyOperator(task_id="dummy2")
dummy3 = DummyOperator(task_id="dummy3")
[dummy1, dummy2] >> dummy3

dag_file_processor = SchedulerJob(dag_ids=[], log=mock.MagicMock())
dag.clear()
dr = dag.create_dagrun(run_id="manual__{}".format(DEFAULT_DATE.isoformat()),
state=State.RUNNING,
execution_date=DEFAULT_DATE)
assert dr is not None

with create_session() as session:
tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)}
# Set dummy1 to skipped and dummy2 to success. dummy3 remains as none.
tis[dummy1.task_id].state = State.SKIPPED
tis[dummy2.task_id].state = State.SUCCESS
assert tis[dummy3.task_id].state == State.NONE

dag_file_processor._process_task_instances(dag, task_instances_list=Mock())

with create_session() as session:
tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)}
assert tis[dummy1.task_id].state == State.SKIPPED
assert tis[dummy2.task_id].state == State.SUCCESS
# dummy3 should be skipped because dummy1 is skipped.
assert tis[dummy3.task_id].state == State.SKIPPED
