From 5ceb0ffa1d33fa73bfae8db62e08d28282855c4c Mon Sep 17 00:00:00 2001 From: Alex Guziel Date: Thu, 7 Nov 2019 13:52:16 -0800 Subject: [PATCH] [AIRFLOW-5870] Allow -1 for infinite pool size --- airflow/models/pool.py | 20 ++++++++++++-------- tests/models/test_pool.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/airflow/models/pool.py b/airflow/models/pool.py index 1cb31182383156..1b2686fc924d80 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -from sqlalchemy import Column, Integer, String, Text +from sqlalchemy import Column, Integer, String, Text, func from airflow.models.base import Base from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING @@ -30,6 +30,7 @@ class Pool(Base): id = Column(Integer, primary_key=True) pool = Column(String(50), unique=True) + # -1 for infinite slots = Column(Integer, default=0) description = Column(Text) @@ -64,10 +65,10 @@ def occupied_slots(self, session): from airflow.models.taskinstance import TaskInstance # Avoid circular import return ( session - .query(TaskInstance) + .query(func.count()) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state.in_(STATES_TO_COUNT_AS_RUNNING)) - .count() + .scalar() ) @provide_session @@ -79,10 +80,10 @@ def used_slots(self, session): running = ( session - .query(TaskInstance) + .query(func.count()) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.RUNNING) - .count() + .scalar() ) return running @@ -95,10 +96,10 @@ def queued_slots(self, session): return ( session - .query(TaskInstance) + .query(func.count()) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.QUEUED) - .count() + .scalar() ) @provide_session @@ -106,4 +107,7 @@ def open_slots(self, session): """ Returns the number of slots open at the moment """ - return self.slots - self.occupied_slots(session) + if self.slots == -1: + return float('inf') + else: + return self.slots - self.occupied_slots(session) diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py index 8559e2374fa4dc..22db8f6eb84527 100644 --- a/tests/models/test_pool.py +++ b/tests/models/test_pool.py @@ -65,6 +65,30 @@ def test_open_slots(self): self.assertEqual(1, pool.queued_slots()) self.assertEqual(2, pool.occupied_slots()) + def test_infinite_slots(self): + pool = Pool(pool='test_pool', slots=-1) + dag = DAG( + dag_id='test_infinite_slots', + start_date=DEFAULT_DATE, ) + t1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool') + t2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool') + ti1 = TI(task=t1, execution_date=DEFAULT_DATE) + ti2 = TI(task=t2, execution_date=DEFAULT_DATE) + ti1.state = State.RUNNING + ti2.state = State.QUEUED + + session = settings.Session + session.add(pool) + session.add(ti1) + session.add(ti2) + session.commit() + session.close() + + self.assertEqual(float('inf'), pool.open_slots()) + self.assertEqual(1, pool.used_slots()) + self.assertEqual(1, pool.queued_slots()) + self.assertEqual(2, pool.occupied_slots()) + def test_default_pool_open_slots(self): set_default_pool_slots(5) self.assertEqual(5, Pool.get_default_pool().open_slots())