diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c743d428ea8560..fec70cc3a501bf 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -24,6 +24,7 @@ import pickle import signal import warnings +from collections import defaultdict from datetime import datetime, timedelta from tempfile import NamedTemporaryFile from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union @@ -146,6 +147,7 @@ def clear_task_instances( :param dag: DAG object """ job_ids = [] + task_id_by_key = defaultdict(lambda: defaultdict(lambda: defaultdict(set))) for ti in tis: if ti.state == State.RUNNING: if ti.job_id: @@ -166,13 +168,36 @@ def clear_task_instances( ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries) ti.state = State.NONE session.merge(ti) + + task_id_by_key[ti.dag_id][ti.execution_date][ti.try_number].add(ti.task_id) + + if task_id_by_key: # Clear all reschedules related to the ti to clear - session.query(TR).filter( - TR.dag_id == ti.dag_id, - TR.task_id == ti.task_id, - TR.execution_date == ti.execution_date, - TR.try_number == ti.try_number, - ).delete() + + # This is an optimization for the common case where all tis are for a small number + # of dag_id, execution_date and try_number. Use a nested dict of dag_id, + # execution_date, try_number and task_id to construct the where clause in a + # hierarchical manner. This speeds up the delete statement by more than 40x for + # large number of tis (50k+). + conditions = or_( + and_( + TR.dag_id == dag_id, + or_( + and_( + TR.execution_date == execution_date, + or_( + and_(TR.try_number == try_number, TR.task_id.in_(task_ids)) + for try_number, task_ids in task_tries.items() + ), + ) + for execution_date, task_tries in dates.items() + ), + ) + for dag_id, dates in task_id_by_key.items() + ) + + delete_qry = TR.__table__.delete().where(conditions) + session.execute(delete_qry) if job_ids: from airflow.jobs.base_job import BaseJob diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index f54bacce339f47..1c5606e66acd84 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -20,8 +20,9 @@ import unittest from airflow import settings -from airflow.models import DAG, TaskInstance as TI, clear_task_instances +from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances from airflow.operators.dummy import DummyOperator +from airflow.sensors.python import PythonSensor from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -138,6 +139,50 @@ def test_clear_task_instances_without_dag(self): assert ti1.try_number == 2 assert ti1.max_tries == 2 + def test_clear_task_instances_with_task_reschedule(self): + """Test that TaskReschedules are deleted correctly when TaskInstances are cleared""" + + with DAG( + 'test_clear_task_instances_with_task_reschedule', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) as dag: + task0 = PythonSensor(task_id='0', python_callable=lambda: False, mode="reschedule") + task1 = PythonSensor(task_id='1', python_callable=lambda: False, mode="reschedule") + + ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + + dag.create_dagrun( + execution_date=ti0.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + + ti0.run() + ti1.run() + + with create_session() as session: + + def count_task_reschedule(task_id): + return ( + session.query(TaskReschedule) + .filter( + TaskReschedule.dag_id == dag.dag_id, + TaskReschedule.task_id == task_id, + TaskReschedule.execution_date == DEFAULT_DATE, + TaskReschedule.try_number == 1, + ) + .count() + ) + + assert count_task_reschedule(ti0.task_id) == 1 + assert count_task_reschedule(ti1.task_id) == 1 + qry = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id).all() + clear_task_instances(qry, session, dag=dag) + assert count_task_reschedule(ti0.task_id) == 0 + assert count_task_reschedule(ti1.task_id) == 1 + def test_dag_clear(self): dag = DAG( 'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)