Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-5391] Do not run skipped tasks when they are cleared #7276

Merged
merged 2 commits into from
Feb 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
Expand Down Expand Up @@ -650,6 +651,7 @@ def deps(self) -> Set[BaseTIDep]:
NotInRetryPeriodDep(),
PrevDagrunDep(),
TriggerRuleDep(),
NotPreviouslySkippedDep(),
}

@property
Expand Down
109 changes: 82 additions & 27 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,40 @@
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
XCOM_SKIPMIXIN_FOLLOWED = "followed"


class SkipMixin(LoggingMixin):
@provide_session
def skip(self, dag_run, execution_date, tasks, session=None):
def _set_state_to_skipped(self, dag_run, execution_date, tasks, session):
"""
Sets tasks instances to skipped from the same dag run.

:param dag_run: the DagRun for which to set the tasks to skipped
:param execution_date: execution_date
:param tasks: tasks to skip (not task_ids)
:param session: db session to use
Used internally to set state of task instances to skipped from the same dag run.
"""
if not tasks:
return

task_ids = [d.task_id for d in tasks]
now = timezone.utcnow()

if dag_run:
session.query(TaskInstance).filter(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.execution_date == dag_run.execution_date,
TaskInstance.task_id.in_(task_ids)
).update({TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now},
synchronize_session=False)
session.commit()
yuqian90 marked this conversation as resolved.
Show resolved Hide resolved
TaskInstance.task_id.in_(task_ids),
).update(
{
TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
synchronize_session=False,
)
else:
if execution_date is None:
raise ValueError("Execution date is None and no dag run")
Expand All @@ -65,13 +68,56 @@ def skip(self, dag_run, execution_date, tasks, session=None):
ti.end_date = now
session.merge(ti)

session.commit()
@provide_session
def skip(
    self, dag_run, execution_date, tasks, session=None,
):
    """
    Sets task instances to skipped from the same dag run.

    If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
    so that NotPreviouslySkippedDep knows these tasks should be skipped when they
    are cleared.

    :param dag_run: the DagRun for which to set the tasks to skipped. May be None when
        a task runs without a dag run; ``execution_date`` is then used instead.
    :param execution_date: execution_date; only consulted when ``dag_run`` is not given
    :param tasks: tasks to skip (not task_ids)
    :param session: db session to use
    """
    if not tasks:
        return

    self._set_state_to_skipped(dag_run, execution_date, tasks, session)
    session.commit()

    # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
    task_id = getattr(self, "task_id", None)
    if task_id is None:
        return

    if dag_run:
        dag_id = dag_run.dag_id
        xcom_execution_date = dag_run.execution_date
    else:
        # Running without a dag run: dereferencing dag_run here would raise AttributeError
        # after the skipped state has already been committed. Fall back to the explicit
        # execution_date (_set_state_to_skipped has already rejected a None value).
        # NOTE(review): assumes every task in ``tasks`` exposes ``dag_id`` -- confirm.
        dag_id = next(iter(tasks)).dag_id
        xcom_execution_date = execution_date

    # Imported here rather than at module level, as in the original implementation.
    from airflow.models.xcom import XCom

    XCom.set(
        key=XCOM_SKIPMIXIN_KEY,
        value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in tasks]},
        task_id=task_id,
        dag_id=dag_id,
        execution_date=xcom_execution_date,
        session=session
    )

def skip_all_except(
self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]
):
"""
This method implements the logic for a branching operator; given a single
task ID or list of task IDs to follow, this skips all other tasks
immediately downstream of this operator.

branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
newly added tasks should be skipped when they are cleared.
"""
self.log.info("Following branch %s", branch_task_ids)
if isinstance(branch_task_ids, str):
Expand All @@ -88,13 +134,22 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable
# is also a downstream task of the branch task, we exclude it from skipping.
branch_downstream_task_ids = set() # type: Set[str]
for b in branch_task_ids:
branch_downstream_task_ids.update(dag.
get_task(b).
get_flat_relative_ids(upstream=False))
branch_downstream_task_ids.update(
dag.get_task(b).get_flat_relative_ids(upstream=False)
)

skip_tasks = [t for t in downstream_tasks
if t.task_id not in branch_task_ids and
t.task_id not in branch_downstream_task_ids]
skip_tasks = [
t
for t in downstream_tasks
if t.task_id not in branch_task_ids
and t.task_id not in branch_downstream_task_ids
]

self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
self.skip(dag_run, ti.execution_date, skip_tasks)
with create_session() as session:
self._set_state_to_skipped(
dag_run, ti.execution_date, skip_tasks, session=session
)
ti.xcom_push(
key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
)
24 changes: 24 additions & 0 deletions airflow/ti_deps/dep_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# specific language governing permissions and limitations
# under the License.

import pendulum
from sqlalchemy.orm.session import Session

from airflow.ti_deps.deps.dag_ti_slots_available_dep import DagTISlotsAvailableDep
from airflow.ti_deps.deps.dag_unpaused_dep import DagUnpausedDep
from airflow.ti_deps.deps.dagrun_exists_dep import DagrunRunningDep
Expand Down Expand Up @@ -90,6 +93,27 @@ def __init__(
self.ignore_ti_state = ignore_ti_state
self.finished_tasks = finished_tasks

def ensure_finished_tasks(self, dag, execution_date: pendulum.datetime, session: Session):
    """
    Return ``finished_tasks``, loading it from the DAG first if it is still None.
    The lazy load supports the strange feature of running tasks without dag_run.

    :param dag: The DAG for which to find finished tasks
    :type dag: airflow.models.DAG
    :param execution_date: The execution_date to look for
    :param session: Database session to use
    :return: A list of all the finished tasks of this DAG and execution_date
    :rtype: list[airflow.models.TaskInstance]
    """
    # Already populated (by the caller or by an earlier call): reuse it as-is.
    if self.finished_tasks is not None:
        return self.finished_tasks

    # Restrict the lookup to the single execution_date and to finished states,
    # counting UPSTREAM_FAILED as finished as well.
    finished_states = State.finished() + [State.UPSTREAM_FAILED]
    self.finished_tasks = dag.get_task_instances(
        start_date=execution_date,
        end_date=execution_date,
        state=finished_states,
        session=session,
    )
    return self.finished_tasks


# In order to be able to get queued a task must have one of these states
SCHEDULEABLE_STATES = {
Expand Down
87 changes: 87 additions & 0 deletions airflow/ti_deps/deps/not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.ti_deps.deps.base_ti_dep import BaseTIDep


class NotPreviouslySkippedDep(BaseTIDep):
    """
    Checks whether a direct upstream SkipMixin task has already decided that this task
    should be skipped, based on the XCom result that upstream task stored.
    """

    NAME = "Not Previously Skipped"
    IGNORABLE = True
    IS_TASK_DEP = True

    def _get_dep_statuses(
        self, ti, session, dep_context
    ):  # pylint: disable=signature-differs
        """
        Yield a single failing status (and mark ``ti`` as SKIPPED) if a finished upstream
        SkipMixin task recorded that ``ti`` must be skipped; otherwise yield nothing.
        """
        from airflow.models.skipmixin import (
            SkipMixin,
            XCOM_SKIPMIXIN_KEY,
            XCOM_SKIPMIXIN_SKIPPED,
            XCOM_SKIPMIXIN_FOLLOWED,
        )
        from airflow.utils.state import State

        parents = ti.task.get_direct_relatives(upstream=True)

        finished = dep_context.ensure_finished_tasks(
            ti.task.dag, ti.execution_date, session
        )
        finished_ids = {finished_ti.task_id for finished_ti in finished}

        for parent in parents:
            # Only SkipMixin parents can have recorded a skip decision.
            if not isinstance(parent, SkipMixin):
                continue

            # This can happen if the parent task has not yet run.
            if parent.task_id not in finished_ids:
                continue

            decision = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY)

            # Nothing stored under the SkipMixin key for this parent
            # (e.g. the parent task has not yet run).
            if decision is None:
                continue

            # Skip any task that is missing from "followed" ...
            not_followed = (
                XCOM_SKIPMIXIN_FOLLOWED in decision
                and ti.task_id not in decision[XCOM_SKIPMIXIN_FOLLOWED]
            )
            # ... or that is explicitly listed in "skipped".
            should_skip = not_followed or (
                XCOM_SKIPMIXIN_SKIPPED in decision
                and ti.task_id in decision[XCOM_SKIPMIXIN_SKIPPED]
            )
            if not should_skip:
                continue

            # The parent SkipMixin has run and its stored XCom result says this ti should
            # be skipped: set ti.state to SKIPPED and fail the rule so the ti does not
            # execute. The first such parent settles the outcome.
            ti.set_state(State.SKIPPED, session)
            yield self._failing_status(
                reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
            )
            return
9 changes: 1 addition & 8 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ def _get_states_count_upstream_ti(ti, finished_tasks, session):
:param finished_tasks: all the finished tasks of the dag_run
:type finished_tasks: list[airflow.models.TaskInstance]
"""
if finished_tasks is None:
# this is for the strange feature of running tasks without dag_run
finished_tasks = ti.task.dag.get_task_instances(
start_date=ti.execution_date,
end_date=ti.execution_date,
state=State.finished() + [State.UPSTREAM_FAILED],
session=session)
counter = Counter(task.state for task in finished_tasks if task.task_id in ti.task.upstream_task_ids)
return counter.get(State.SUCCESS, 0), counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), \
counter.get(State.UPSTREAM_FAILED, 0), sum(counter.values())
Expand All @@ -71,7 +64,7 @@ def _get_dep_statuses(self, ti, session, dep_context):
# see if the task name is in the task upstream for our task
successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
ti=ti,
finished_tasks=dep_context.finished_tasks)
finished_tasks=dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session))

yield from self._evaluate_trigger_rule(
ti=ti,
Expand Down
39 changes: 39 additions & 0 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,3 +3021,42 @@ def test_reset_orphaned_tasks_with_orphans(self):
self.assertEqual(state, ti.state)

session.close()


def test_task_with_upstream_skip_process_task_instances():
    """
    Check that _process_task_instances moves a task instance to SKIPPED when one of its
    upstream tasks is skipped, as decided by TriggerRuleDep.
    """
    with DAG(
        dag_id='test_task_with_upstream_skip_dag',
        start_date=DEFAULT_DATE,
        schedule_interval=None
    ) as dag:
        dummy1 = DummyOperator(task_id='dummy1')
        dummy2 = DummyOperator(task_id="dummy2")
        dummy3 = DummyOperator(task_id="dummy3")
        # dummy3 depends on both dummy1 and dummy2.
        [dummy1, dummy2] >> dummy3

    dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
    dag.clear()
    run_id = f"manual__{DEFAULT_DATE.isoformat()}"
    dag_run = dag.create_dagrun(
        run_id=run_id,
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
    )
    assert dag_run is not None

    # Arrange: dummy1 is skipped, dummy2 succeeded, dummy3 has no state yet.
    with create_session() as session:
        ti_by_task_id = {
            task_instance.task_id: task_instance
            for task_instance in dag_run.get_task_instances(session=session)
        }
        ti_by_task_id[dummy1.task_id].state = State.SKIPPED
        ti_by_task_id[dummy2.task_id].state = State.SUCCESS
        assert ti_by_task_id[dummy3.task_id].state == State.NONE

    # Act
    dag_file_processor._process_task_instances(dag, task_instances_list=Mock())

    # Assert: upstream states are unchanged and dummy3 was skipped because dummy1 is skipped.
    with create_session() as session:
        ti_by_task_id = {
            task_instance.task_id: task_instance
            for task_instance in dag_run.get_task_instances(session=session)
        }
        assert ti_by_task_id[dummy1.task_id].state == State.SKIPPED
        assert ti_by_task_id[dummy2.task_id].state == State.SUCCESS
        assert ti_by_task_id[dummy3.task_id].state == State.SKIPPED
Loading