Skip to content

Commit

Permalink
[AIRFLOW-5883] Don't use .count() from sqlalchemy to count (#6532)
Browse files Browse the repository at this point in the history
(cherry picked from commit 084d5f8)
  • Loading branch information
saguziel authored and kaxil committed Dec 17, 2019
1 parent 0c96e5d commit bb10ed0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
5 changes: 3 additions & 2 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,15 +669,16 @@ def get_num_active_runs(self, external_trigger=None, session=None):
:param session:
:return: number greater than 0 for active dag runs
"""
# .count() is inefficient
query = (session
.query(DagRun)
.query(func.count())
.filter(DagRun.dag_id == self.dag_id)
.filter(DagRun.state == State.RUNNING))

if external_trigger is not None:
query = query.filter(DagRun.external_trigger == external_trigger)

return query.count()
return query.scalar()

@provide_session
def get_dagrun(self, execution_date, session=None):
Expand Down
5 changes: 3 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,11 +1410,12 @@ def xcom_pull(
@provide_session
def get_num_running_task_instances(self, session):
TI = TaskInstance
return session.query(TI).filter(
# .count() is inefficient
return session.query(func.count()).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.state == State.RUNNING
).count()
).scalar()

def init_run_context(self, raw=False):
"""
Expand Down
12 changes: 8 additions & 4 deletions airflow/sensors/external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import os

from sqlalchemy import func

from airflow.exceptions import AirflowException
from airflow.models import TaskInstance, DagBag, DagModel, DagRun
from airflow.sensors.base_sensor_operator import BaseSensorOperator
Expand Down Expand Up @@ -137,18 +139,20 @@ def poke(self, context, session=None):
self.external_dag_id))

if self.external_task_id:
count = session.query(TI).filter(
# .count() is inefficient
count = session.query(func.count()).filter(
TI.dag_id == self.external_dag_id,
TI.task_id == self.external_task_id,
TI.state.in_(self.allowed_states),
TI.execution_date.in_(dttm_filter),
).count()
).scalar()
else:
count = session.query(DR).filter(
# .count() is inefficient
count = session.query(func.count()).filter(
DR.dag_id == self.external_dag_id,
DR.state.in_(self.allowed_states),
DR.execution_date.in_(dttm_filter),
).count()
).scalar()

session.commit()
return count == len(dttm_filter)

0 comments on commit bb10ed0

Please sign in to comment.