Speed up clear_task_instances by doing a single sql delete for TaskReschedule (#14048)

Clearing a large number of tasks takes a long time. Most of the time (more than 95%) is spent on this statement in clear_task_instances. The slowness sometimes causes the webserver to time out because web_server_worker_timeout is hit.

```
        # Clear all reschedules related to the ti to clear
        session.query(TR).filter(
            TR.dag_id == ti.dag_id,
            TR.task_id == ti.task_id,
            TR.execution_date == ti.execution_date,
            TR.try_number == ti.try_number,
        ).delete()
```
This statement was very slow because it runs inside a for loop and deletes the TaskReschedule rows one task instance at a time.

This PR changes the code to delete the TaskReschedule rows in a single SQL query built from nested OR conditions. It does effectively the same thing, but much faster.

Profiling showed a large speed improvement (roughly 40 to 50 times faster) over the first iteration of this change, so the overall delete should now be about 300 times faster than the original for-loop deletion.
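For reference, below is a minimal, standalone sketch of the approach, assuming only SQLAlchemy. The `task_reschedule` table definition, the `FakeTI` tuple and the sample data are placeholders invented for illustration (the real code uses the `TaskReschedule` model and the caller's `session`), and the generator expressions are unpacked with `*` here for clarity, whereas the actual PR code passes them to `or_()`/`and_()` directly.

```
from collections import defaultdict
from datetime import datetime
from typing import NamedTuple

from sqlalchemy import Column, DateTime, Integer, MetaData, String, Table, and_, or_

# Stand-in for Airflow's task_reschedule table; only the columns used in the
# WHERE clause are declared here.
metadata = MetaData()
task_reschedule = Table(
    "task_reschedule",
    metadata,
    Column("dag_id", String),
    Column("task_id", String),
    Column("execution_date", DateTime),
    Column("try_number", Integer),
)


class FakeTI(NamedTuple):
    """Hypothetical minimal task-instance record, standing in for TaskInstance."""

    dag_id: str
    task_id: str
    execution_date: datetime
    try_number: int


tis = [
    FakeTI("example_dag", "extract", datetime(2021, 3, 1), 1),
    FakeTI("example_dag", "load", datetime(2021, 3, 1), 1),
]

# Group the task instances to clear by dag_id -> execution_date -> try_number,
# the same nested-dict shape the PR uses to build the hierarchical WHERE clause.
task_id_by_key = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
for ti in tis:
    task_id_by_key[ti.dag_id][ti.execution_date][ti.try_number].add(ti.task_id)

# Build one nested OR/AND condition instead of issuing one DELETE per task instance.
conditions = or_(
    *(
        and_(
            task_reschedule.c.dag_id == dag_id,
            or_(
                *(
                    and_(
                        task_reschedule.c.execution_date == execution_date,
                        or_(
                            *(
                                and_(
                                    task_reschedule.c.try_number == try_number,
                                    task_reschedule.c.task_id.in_(task_ids),
                                )
                                for try_number, task_ids in task_tries.items()
                            )
                        ),
                    )
                    for execution_date, task_tries in dates.items()
                )
            ),
        )
        for dag_id, dates in task_id_by_key.items()
    )
)

# A single DELETE statement now covers every task instance being cleared.
delete_qry = task_reschedule.delete().where(conditions)
print(delete_qry)
```

Printing `delete_qry` shows that everything collapses into one `DELETE FROM task_reschedule WHERE ...` statement, which is where the speedup comes from.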

(cherry picked from commit 9036ce2)
yuqian90 authored and ashb committed Mar 18, 2021
1 parent 04ae0f6 commit 118f86c
Showing 2 changed files with 77 additions and 7 deletions.
airflow/models/taskinstance.py (31 additions, 6 deletions)
```
@@ -24,6 +24,7 @@
 import pickle
 import signal
 import warnings
+from collections import defaultdict
 from datetime import datetime, timedelta
 from tempfile import NamedTemporaryFile
 from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
@@ -146,6 +147,7 @@ def clear_task_instances(
     :param dag: DAG object
     """
     job_ids = []
+    task_id_by_key = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
     for ti in tis:
         if ti.state == State.RUNNING:
             if ti.job_id:
@@ -166,13 +168,36 @@
                 ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
             ti.state = State.NONE
             session.merge(ti)
 
+        task_id_by_key[ti.dag_id][ti.execution_date][ti.try_number].add(ti.task_id)
+
+    if task_id_by_key:
         # Clear all reschedules related to the ti to clear
-        session.query(TR).filter(
-            TR.dag_id == ti.dag_id,
-            TR.task_id == ti.task_id,
-            TR.execution_date == ti.execution_date,
-            TR.try_number == ti.try_number,
-        ).delete()
+
+        # This is an optimization for the common case where all tis are for a small number
+        # of dag_id, execution_date and try_number. Use a nested dict of dag_id,
+        # execution_date, try_number and task_id to construct the where clause in a
+        # hierarchical manner. This speeds up the delete statement by more than 40x for
+        # large number of tis (50k+).
+        conditions = or_(
+            and_(
+                TR.dag_id == dag_id,
+                or_(
+                    and_(
+                        TR.execution_date == execution_date,
+                        or_(
+                            and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
+                            for try_number, task_ids in task_tries.items()
+                        ),
+                    )
+                    for execution_date, task_tries in dates.items()
+                ),
+            )
+            for dag_id, dates in task_id_by_key.items()
+        )
+
+        delete_qry = TR.__table__.delete().where(conditions)
+        session.execute(delete_qry)
 
     if job_ids:
         from airflow.jobs.base_job import BaseJob
```
tests/models/test_cleartasks.py (46 additions, 1 deletion)
```
@@ -20,8 +20,9 @@
 import unittest
 
 from airflow import settings
-from airflow.models import DAG, TaskInstance as TI, clear_task_instances
+from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
 from airflow.operators.dummy import DummyOperator
+from airflow.sensors.python import PythonSensor
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
@@ -138,6 +139,50 @@ def test_clear_task_instances_without_dag(self):
         assert ti1.try_number == 2
         assert ti1.max_tries == 2
 
+    def test_clear_task_instances_with_task_reschedule(self):
+        """Test that TaskReschedules are deleted correctly when TaskInstances are cleared"""
+
+        with DAG(
+            'test_clear_task_instances_with_task_reschedule',
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+        ) as dag:
+            task0 = PythonSensor(task_id='0', python_callable=lambda: False, mode="reschedule")
+            task1 = PythonSensor(task_id='1', python_callable=lambda: False, mode="reschedule")
+
+        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
+        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+
+        dag.create_dagrun(
+            execution_date=ti0.execution_date,
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        ti0.run()
+        ti1.run()
+
+        with create_session() as session:
+
+            def count_task_reschedule(task_id):
+                return (
+                    session.query(TaskReschedule)
+                    .filter(
+                        TaskReschedule.dag_id == dag.dag_id,
+                        TaskReschedule.task_id == task_id,
+                        TaskReschedule.execution_date == DEFAULT_DATE,
+                        TaskReschedule.try_number == 1,
+                    )
+                    .count()
+                )
+
+            assert count_task_reschedule(ti0.task_id) == 1
+            assert count_task_reschedule(ti1.task_id) == 1
+            qry = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id).all()
+            clear_task_instances(qry, session, dag=dag)
+            assert count_task_reschedule(ti0.task_id) == 0
+            assert count_task_reschedule(ti1.task_id) == 1
+
     def test_dag_clear(self):
         dag = DAG(
             'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)
```
