diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index f7b5cb1511e53..d6447c754f54c 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -811,14 +811,7 @@ def process_file(self, file_path, zombies, pickle_dags=False, session=None): # Save individual DAGs in the ORM and update DagModel.last_scheduled_time dagbag.sync_to_db() - paused_dag_ids = ( - session.query(DagModel.dag_id) - .filter(DagModel.is_paused.is_(True)) - .filter(DagModel.dag_id.in_(dagbag.dag_ids)) - .all() - ) - - paused_dag_ids = set(paused_dag_id for paused_dag_id, in paused_dag_ids) + paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) # Pickle the DAGs (if necessary) and put them into a SimpleDag for dag_id, dag in dagbag.dags.items(): diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 246e1437b0bac..ddfdd5b13198f 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -26,7 +26,7 @@ import traceback from collections import OrderedDict, defaultdict from datetime import datetime, timedelta -from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Type, Union +from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union import jinja2 import pendulum @@ -1777,6 +1777,26 @@ def get_last_dagrun(self, session=None, include_externally_triggered=False): return get_last_dagrun(self.dag_id, session=session, include_externally_triggered=include_externally_triggered) + @staticmethod + @provide_session + def get_paused_dag_ids(dag_ids: List[str], session: Session = None) -> Set[str]: + """ + Given a list of dag_ids, get a set of Paused Dag Ids + + :param dag_ids: List of Dag ids + :param session: ORM Session + :return: Paused Dag_ids + """ + paused_dag_ids = ( + session.query(DagModel.dag_id) + .filter(DagModel.is_paused.is_(True)) + .filter(DagModel.dag_id.in_(dag_ids)) + .all() + ) + + paused_dag_ids = set(paused_dag_id for paused_dag_id, in paused_dag_ids) + return paused_dag_ids + @property def safe_dag_id(self): return self.dag_id.replace('.', '__dot__') diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 8d4ec59acdbb4..e5b64c6f0dcd6 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1346,6 +1346,20 @@ class DAGsubclass(DAG): self.assertNotEqual(hash(dag_diff_name), hash(dag)) self.assertNotEqual(hash(dag_subclass), hash(dag)) + def test_get_paused_dag_ids(self): + dag_id = "test_get_paused_dag_ids" + dag = DAG(dag_id, is_paused_upon_creation=True) + dag.sync_to_db() + self.assertIsNotNone(DagModel.get_dagmodel(dag_id)) + + paused_dag_ids = DagModel.get_paused_dag_ids([dag_id]) + self.assertEqual(paused_dag_ids, {dag_id}) + + with create_session() as session: + session.query(DagModel).filter( + DagModel.dag_id == dag_id).delete( + synchronize_session=False) + class TestQueries(unittest.TestCase):