Skip to content
This repository has been archived by the owner on May 22, 2021. It is now read-only.

Commit

Permalink
[AIRFLOW-6957] Make retrieving Paused Dag ids a separate method (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored and galuszkak committed Mar 5, 2020
1 parent 96f3c64 commit 1d91738
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
9 changes: 1 addition & 8 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,7 @@ def process_file(self, file_path, zombies, pickle_dags=False, session=None):
# Save individual DAGs in the ORM and update DagModel.last_scheduled_time
dagbag.sync_to_db()

paused_dag_ids = (
session.query(DagModel.dag_id)
.filter(DagModel.is_paused.is_(True))
.filter(DagModel.dag_id.in_(dagbag.dag_ids))
.all()
)

paused_dag_ids = set(paused_dag_id for paused_dag_id, in paused_dag_ids)
paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)

# Pickle the DAGs (if necessary) and put them into a SimpleDag
for dag_id, dag in dagbag.dags.items():
Expand Down
22 changes: 21 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import traceback
from collections import OrderedDict, defaultdict
from datetime import datetime, timedelta
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Type, Union
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union

import jinja2
import pendulum
Expand Down Expand Up @@ -1777,6 +1777,26 @@ def get_last_dagrun(self, session=None, include_externally_triggered=False):
return get_last_dagrun(self.dag_id, session=session,
include_externally_triggered=include_externally_triggered)

@staticmethod
@provide_session
def get_paused_dag_ids(dag_ids: List[str], session: Session = None) -> Set[str]:
"""
Given a list of dag_ids, get a set of Paused Dag Ids
:param dag_ids: List of Dag ids
:param session: ORM Session
:return: Paused Dag_ids
"""
paused_dag_ids = (
session.query(DagModel.dag_id)
.filter(DagModel.is_paused.is_(True))
.filter(DagModel.dag_id.in_(dag_ids))
.all()
)

paused_dag_ids = set(paused_dag_id for paused_dag_id, in paused_dag_ids)
return paused_dag_ids

@property
def safe_dag_id(self):
return self.dag_id.replace('.', '__dot__')
Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,20 @@ class DAGsubclass(DAG):
self.assertNotEqual(hash(dag_diff_name), hash(dag))
self.assertNotEqual(hash(dag_subclass), hash(dag))

def test_get_paused_dag_ids(self):
dag_id = "test_get_paused_dag_ids"
dag = DAG(dag_id, is_paused_upon_creation=True)
dag.sync_to_db()
self.assertIsNotNone(DagModel.get_dagmodel(dag_id))

paused_dag_ids = DagModel.get_paused_dag_ids([dag_id])
self.assertEqual(paused_dag_ids, {dag_id})

with create_session() as session:
session.query(DagModel).filter(
DagModel.dag_id == dag_id).delete(
synchronize_session=False)


class TestQueries(unittest.TestCase):

Expand Down

0 comments on commit 1d91738

Please sign in to comment.