Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up clear_task_instances by doing a single sql delete for TaskReschedule #14048

Merged
merged 9 commits into from
Feb 10, 2021
37 changes: 31 additions & 6 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pickle
import signal
import warnings
from collections import defaultdict
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
Expand Down Expand Up @@ -146,6 +147,7 @@ def clear_task_instances(
:param dag: DAG object
"""
job_ids = []
task_id_by_key = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
for ti in tis:
if ti.state == State.RUNNING:
if ti.job_id:
Expand All @@ -166,13 +168,36 @@ def clear_task_instances(
ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
ti.state = State.NONE
session.merge(ti)

task_id_by_key[ti.dag_id][ti.execution_date][ti.try_number].add(ti.task_id)

if task_id_by_key:
# Clear all reschedules related to the ti to clear
session.query(TR).filter(
TR.dag_id == ti.dag_id,
TR.task_id == ti.task_id,
TR.execution_date == ti.execution_date,
TR.try_number == ti.try_number,
).delete()

# This is an optimization for the common case where all tis share a small number
# of (dag_id, execution_date, try_number) combinations. Use a nested dict keyed by
# dag_id, then execution_date, then try_number, whose leaves are sets of task_ids,
# to construct the where clause in a hierarchical manner. This speeds up the delete
# statement by more than 40x for a large number of tis (50k+).
conditions = or_(
and_(
TR.dag_id == dag_id,
or_(
and_(
TR.execution_date == execution_date,
or_(
and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
for try_number, task_ids in task_tries.items()
),
)
for execution_date, task_tries in dates.items()
),
)
for dag_id, dates in task_id_by_key.items()
)

delete_qry = TR.__table__.delete().where(conditions)
session.execute(delete_qry)

if job_ids:
from airflow.jobs.base_job import BaseJob
Expand Down
47 changes: 46 additions & 1 deletion tests/models/test_cleartasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import unittest

from airflow import settings
from airflow.models import DAG, TaskInstance as TI, clear_task_instances
from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
from airflow.operators.dummy import DummyOperator
from airflow.sensors.python import PythonSensor
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
Expand Down Expand Up @@ -138,6 +139,50 @@ def test_clear_task_instances_without_dag(self):
assert ti1.try_number == 2
assert ti1.max_tries == 2

def test_clear_task_instances_with_task_reschedule(self):
    """Clearing one TaskInstance must delete only that instance's TaskReschedule rows.

    Two reschedule-mode sensors whose callable always returns ``False`` are run once
    each; the test expects one TaskReschedule row per sensor afterwards. Only the
    first sensor's TaskInstance is then cleared, and the test checks that its
    reschedule row is gone while the second sensor's row is untouched.
    """
    dag_end_date = DEFAULT_DATE + datetime.timedelta(days=10)
    with DAG(
        'test_clear_task_instances_with_task_reschedule',
        start_date=DEFAULT_DATE,
        end_date=dag_end_date,
    ) as dag:
        cleared_sensor = PythonSensor(task_id='0', python_callable=lambda: False, mode="reschedule")
        kept_sensor = PythonSensor(task_id='1', python_callable=lambda: False, mode="reschedule")

    cleared_ti = TI(task=cleared_sensor, execution_date=DEFAULT_DATE)
    kept_ti = TI(task=kept_sensor, execution_date=DEFAULT_DATE)

    dag.create_dagrun(
        execution_date=cleared_ti.execution_date,
        state=State.RUNNING,
        run_type=DagRunType.SCHEDULED,
    )

    # Run both sensors once; the assertions below expect one reschedule row each.
    cleared_ti.run()
    kept_ti.run()

    with create_session() as session:

        def count_task_reschedule(task_id):
            """Count first-try TaskReschedule rows for ``task_id`` in this DAG at DEFAULT_DATE."""
            matching_rows = session.query(TaskReschedule).filter(
                TaskReschedule.dag_id == dag.dag_id,
                TaskReschedule.task_id == task_id,
                TaskReschedule.execution_date == DEFAULT_DATE,
                TaskReschedule.try_number == 1,
            )
            return matching_rows.count()

        # Precondition: each sensor has exactly one reschedule row.
        assert count_task_reschedule(cleared_ti.task_id) == 1
        assert count_task_reschedule(kept_ti.task_id) == 1

        # Clear only the task instances belonging to the first sensor.
        tis_to_clear = (
            session.query(TI)
            .filter(TI.dag_id == dag.dag_id, TI.task_id == cleared_ti.task_id)
            .all()
        )
        clear_task_instances(tis_to_clear, session, dag=dag)

        # The cleared sensor's row is deleted; the other sensor's row remains.
        assert count_task_reschedule(cleared_ti.task_id) == 0
        assert count_task_reschedule(kept_ti.task_id) == 1

def test_dag_clear(self):
dag = DAG(
'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)
Expand Down