diff --git a/UPDATING.md b/UPDATING.md
index 2002cdcfea481f..3b226e07a6adbc 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -23,6 +23,15 @@ assists users migrating to a new version.
 
 ## Airflow Master
 
+### Removal of `non_pooled_task_slot_count` and `non_pooled_backfill_task_slot_count`
+
+`non_pooled_task_slot_count` and `non_pooled_backfill_task_slot_count`
+have been removed in favor of a real pool, `default_pool`.
+
+By default, tasks run in `default_pool`, which is initialized with 128
+slots; users can change the number of slots through the UI or CLI.
+`default_pool` cannot be removed.
+
 ### Changes to Google Transfer Operator
 
 To obtain pylint compatibility the `filter ` argument in `GcpTransferServiceOperationsListOperator` has been renamed to `request_filter`.
diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py
index 22b458ed8bd43f..34763028c5e184 100644
--- a/airflow/api/common/experimental/pool.py
+++ b/airflow/api/common/experimental/pool.py
@@ -72,6 +72,9 @@ def delete_pool(name, session=None):
     if not (name and name.strip()):
         raise AirflowBadRequest("Pool name shouldn't be empty")
 
+    if name == Pool.DEFAULT_POOL_NAME:
+        raise AirflowBadRequest("default_pool cannot be deleted")
+
     pool = session.query(Pool).filter_by(pool=name).first()
     if pool is None:
         raise PoolNotFound("Pool '%s' doesn't exist" % name)
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index da8ad93b5162b8..bc41a10ec559f8 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -132,14 +132,6 @@ dag_concurrency = 16
 # Are DAGs paused by default at creation
 dags_are_paused_at_creation = True
 
-# When not using pools, tasks are run in the "default pool",
-# whose size is guided by this config element
-non_pooled_task_slot_count = 128
-
-# When not using pools, the number of backfill tasks per backfill
-# is limited by this config element
-non_pooled_backfill_task_slot_count = %(non_pooled_task_slot_count)s
-
 # The maximum number of active DAG runs per DAG
 max_active_runs_per_dag = 16
diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg
index aca8008685f6ba..0abf9800ed48fe 100644
--- a/airflow/config_templates/default_test.cfg
+++ b/airflow/config_templates/default_test.cfg
@@ -46,7 +46,6 @@ donot_pickle = False
 dag_concurrency = 16
 dags_are_paused_at_creation = False
 fernet_key = {FERNET_KEY}
-non_pooled_task_slot_count = 128
 enable_xcom_pickling = False
 killed_task_cleanup_time = 5
 secure_mode = False
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 34aa6bdf6eb31a..680998599f0de4 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -23,7 +23,6 @@
 
 from sqlalchemy.orm.session import make_transient
 
-from airflow import configuration as conf
 from airflow import executors, models
 from airflow.exceptions import (
     AirflowException,
@@ -542,32 +541,24 @@ def _per_task_process(task, key, ti, session=None):
                 self.log.debug('Adding %s to not_ready', ti)
                 ti_status.not_ready.add(key)
 
-        non_pool_slots = conf.getint('core', 'non_pooled_backfill_task_slot_count')
-
         try:
             for task in self.dag.topological_sort():
                 for key, ti in list(ti_status.to_run.items()):
                     if task.task_id != ti.task_id:
                         continue
-                    if task.pool:
-                        pool = session.query(models.Pool) \
-                            .filter(models.Pool.pool == task.pool) \
-                            .first()
-                        if not pool:
-                            raise PoolNotFound('Unknown pool: {}'.format(task.pool))
-
-                        open_slots = pool.open_slots(session=session)
-                        if open_slots <= 0:
-                            raise NoAvailablePoolSlot(
-                                "Not scheduling since there are "
-                                "%s open slots in pool %s".format(
-                                    open_slots, task.pool))
-                    else:
-                        if non_pool_slots <= 0:
-                            raise NoAvailablePoolSlot(
-                                "Not scheduling since there are no "
-                                "non_pooled_backfill_task_slot_count.")
-                        non_pool_slots -= 1
+
+                    pool = session.query(models.Pool) \
+                        .filter(models.Pool.pool == task.pool) \
+                        .first()
+                    if not pool:
+                        raise PoolNotFound('Unknown pool: {}'.format(task.pool))
+
+                    open_slots = pool.open_slots(session=session)
+                    if open_slots <= 0:
+                        raise NoAvailablePoolSlot(
+                            "Not scheduling since there are "
+                            "{0} open slots in pool {1}".format(
+                                open_slots, task.pool))
 
                 num_running_task_instances_in_dag = DAG.get_num_task_instances(
                     self.dag_id,
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 7d166c73935c41..0c2f64795302f0 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -893,21 +893,14 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
             # any open slots in the pool.
             for pool, task_instances in pool_to_task_instances.items():
                 pool_name = pool
-                if not pool:
-                    # Arbitrary:
-                    # If queued outside of a pool, trigger no more than
-                    # non_pooled_task_slot_count
-                    open_slots = models.Pool.default_pool_open_slots()
-                    pool_name = models.Pool.default_pool_name
+                if pool not in pools:
+                    self.log.warning(
+                        "Tasks using non-existent pool '%s' will not be scheduled",
+                        pool
+                    )
+                    open_slots = 0
                 else:
-                    if pool not in pools:
-                        self.log.warning(
-                            "Tasks using non-existent pool '%s' will not be scheduled",
-                            pool
-                        )
-                        open_slots = 0
-                    else:
-                        open_slots = pools[pool].open_slots(session=session)
+                    open_slots = pools[pool].open_slots(session=session)
 
                 num_ready = len(task_instances)
                 self.log.info(
diff --git a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py
new file mode 100644
index 00000000000000..959d04e7a5dbda
--- /dev/null
+++ b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py
@@ -0,0 +1,122 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Make TaskInstance.pool not nullable
+
+Revision ID: 6e96a59344a4
+Revises: 939bb1e647c8
+Create Date: 2019-06-13 21:51:32.878437
+
+"""
+
+from alembic import op
+import dill
+import sqlalchemy as sa
+from sqlalchemy import Column, Float, Integer, PickleType, String
+from sqlalchemy.ext.declarative import declarative_base
+
+from airflow.utils.db import create_session
+from airflow.utils.sqlalchemy import UtcDateTime
+
+
+# revision identifiers, used by Alembic.
+revision = '6e96a59344a4' +down_revision = '939bb1e647c8' +branch_labels = None +depends_on = None + + +Base = declarative_base() +ID_LEN = 250 + + +class TaskInstance(Base): + """ + Task instances store the state of a task instance. This table is the + authority and single source of truth around what tasks have run and the + state they are in. + + The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or + dag model deliberately to have more control over transactions. + + Database transactions on this table should insure double triggers and + any confusion around what task instances are or aren't ready to run + even while multiple schedulers may be firing task instances. + """ + + __tablename__ = "task_instance" + + task_id = Column(String(ID_LEN), primary_key=True) + dag_id = Column(String(ID_LEN), primary_key=True) + execution_date = Column(UtcDateTime, primary_key=True) + start_date = Column(UtcDateTime) + end_date = Column(UtcDateTime) + duration = Column(Float) + state = Column(String(20)) + _try_number = Column('try_number', Integer, default=0) + max_tries = Column(Integer) + hostname = Column(String(1000)) + unixname = Column(String(1000)) + job_id = Column(Integer) + pool = Column(String(50), nullable=False) + queue = Column(String(256)) + priority_weight = Column(Integer) + operator = Column(String(1000)) + queued_dttm = Column(UtcDateTime) + pid = Column(Integer) + executor_config = Column(PickleType(pickler=dill)) + + +def upgrade(): + """ + Make TaskInstance.pool field not nullable. + """ + with create_session() as session: + session.query(TaskInstance)\ + .filter(TaskInstance.pool.is_(None))\ + .update({TaskInstance.pool: 'default_pool'}, + synchronize_session=False) # Avoid select updated rows + session.commit() + + # use batch_alter_table to support SQLite workaround + with op.batch_alter_table('task_instance') as batch_op: + batch_op.alter_column( + column_name='pool', + type_=sa.String(50), + nullable=False, + ) + + +def downgrade(): + """ + Make TaskInstance.pool field nullable. 
+ """ + # use batch_alter_table to support SQLite workaround + with op.batch_alter_table('task_instance') as batch_op: + batch_op.alter_column( + column_name='pool', + type_=sa.String(50), + nullable=True, + ) + + with create_session() as session: + session.query(TaskInstance)\ + .filter(TaskInstance.pool == 'default_pool')\ + .update({TaskInstance.pool: None}, + synchronize_session=False) # Avoid select updated rows + session.commit() diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f724849d81b90b..31114c04d41d2e 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -34,6 +34,7 @@ from airflow.exceptions import AirflowException from airflow.lineage import prepare_lineage, apply_lineage, DataSet from airflow.models.dag import DAG +from airflow.models.pool import Pool from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.xcom import XCOM_RETURN_KEY from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep @@ -258,7 +259,7 @@ def __init__( priority_weight: int = 1, weight_rule: str = WeightRule.DOWNSTREAM, queue: str = configuration.conf.get('celery', 'default_queue'), - pool: Optional[str] = None, + pool: str = Pool.DEFAULT_POOL_NAME, sla: Optional[timedelta] = None, execution_timeout: Optional[timedelta] = None, on_failure_callback: Optional[Callable] = None, diff --git a/airflow/models/pool.py b/airflow/models/pool.py index a7ecebf3a2c8d6..e2405a2c5efca6 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -19,7 +19,6 @@ from sqlalchemy import Column, Integer, String, Text, func -from airflow import conf from airflow.models.base import Base from airflow.utils.state import State from airflow.utils.db import provide_session @@ -33,20 +32,20 @@ class Pool(Base): slots = Column(Integer, default=0) description = Column(Text) - default_pool_name = 'not_pooled' + DEFAULT_POOL_NAME = 'default_pool' def __repr__(self): return self.pool @staticmethod @provide_session - def default_pool_open_slots(session): - from airflow.models import TaskInstance as TI # To avoid circular imports - total_slots = conf.getint('core', 'non_pooled_task_slot_count') - used_slots = session.query(func.count()).filter( - TI.pool == Pool.default_pool_name).filter( - TI.state.in_([State.RUNNING, State.QUEUED])).scalar() - return total_slots - used_slots + def get_pool(pool_name, session=None): + return session.query(Pool).filter(Pool.pool == pool_name).first() + + @staticmethod + @provide_session + def get_default_pool(session=None): + return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session) def to_json(self): return { diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c40d81160aecef..af8ca6377a5d46 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -140,7 +140,7 @@ class TaskInstance(Base, LoggingMixin): hostname = Column(String(1000)) unixname = Column(String(1000)) job_id = Column(Integer) - pool = Column(String(50)) + pool = Column(String(50), nullable=False) queue = Column(String(256)) priority_weight = Column(Integer) operator = Column(String(1000)) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 4de470261e9912..5a86a56b75f91b 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -23,6 +23,7 @@ import contextlib from airflow import settings +from airflow.configuration import conf from airflow.utils.log.logging_mixin import LoggingMixin log = LoggingMixin().log @@ -78,6 +79,20 @@ def 
merge_conn(conn, session=None): session.commit() +@provide_session +def add_default_pool_if_not_exists(session=None): + from airflow.models.pool import Pool + if not Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session): + default_pool = Pool( + pool=Pool.DEFAULT_POOL_NAME, + slots=conf.getint(section='core', key='non_pooled_task_slot_count', + fallback=128), + description="Default pool", + ) + session.add(default_pool) + session.commit() + + def initdb(): from airflow import models from airflow.models import Connection @@ -311,6 +326,7 @@ def upgradedb(): config.set_main_option('script_location', directory.replace('%', '%%')) config.set_main_option('sqlalchemy.url', settings.SQL_ALCHEMY_CONN.replace('%', '%%')) command.upgrade(config, 'heads') + add_default_pool_if_not_exists() def resetdb(): diff --git a/airflow/www/views.py b/airflow/www/views.py index 3a6cc8bb92aee4..0d2c6110ee790f 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2109,6 +2109,10 @@ class PoolModelView(AirflowModelView): @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False) def action_muldelete(self, items): + if any(item.pool == models.Pool.DEFAULT_POOL_NAME for item in items): + flash("default_pool cannot be deleted", 'error') + self.update_redirect() + return redirect(self.get_redirect()) self.datamodel.delete_all(items) self.update_redirect() return redirect(self.get_redirect()) diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py index 56fb9e90b528c9..0050fd5ac743f0 100644 --- a/tests/api/client/test_local_client.py +++ b/tests/api/client/test_local_client.py @@ -124,18 +124,19 @@ def test_get_pools(self): self.client.create_pool(name='foo1', slots=1, description='') self.client.create_pool(name='foo2', slots=2, description='') pools = sorted(self.client.get_pools(), key=lambda p: p[0]) - self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')]) + self.assertEqual(pools, [('default_pool', 128, 'Default pool'), + ('foo1', 1, ''), ('foo2', 2, '')]) def test_create_pool(self): pool = self.client.create_pool(name='foo', slots=1, description='') self.assertEqual(pool, ('foo', 1, '')) with create_session() as session: - self.assertEqual(session.query(models.Pool).count(), 1) + self.assertEqual(session.query(models.Pool).count(), 2) def test_delete_pool(self): self.client.create_pool(name='foo', slots=1, description='') with create_session() as session: - self.assertEqual(session.query(models.Pool).count(), 1) + self.assertEqual(session.query(models.Pool).count(), 2) self.client.delete_pool(name='foo') with create_session() as session: - self.assertEqual(session.query(models.Pool).count(), 0) + self.assertEqual(session.query(models.Pool).count(), 1) diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py index d1be1702e6c26f..1a639b84dabef2 100644 --- a/tests/api/common/experimental/test_pool.py +++ b/tests/api/common/experimental/test_pool.py @@ -20,6 +20,7 @@ import unittest from airflow import models +from airflow.models.pool import Pool from airflow.api.common.experimental import pool as pool_api from airflow.exceptions import AirflowBadRequest, PoolNotFound from airflow.utils.db import create_session @@ -28,9 +29,13 @@ class TestPool(unittest.TestCase): + USER_POOL_COUNT = 2 + TOTAL_POOL_COUNT = USER_POOL_COUNT + 1 # including default_pool + def setUp(self): - self.pools = [] - for i in range(2): + clear_db_pools() + self.pools = [Pool.get_default_pool()] + for i in 
range(self.USER_POOL_COUNT): name = 'experimental_%s' % (i + 1) pool = models.Pool( pool=name, @@ -41,9 +46,6 @@ def setUp(self): with create_session() as session: session.add_all(self.pools) - def tearDown(self): - clear_db_pools() - def test_get_pool(self): pool = pool_api.get_pool(name=self.pools[0].pool) self.assertEqual(pool.pool, self.pools[0].pool) @@ -75,7 +77,7 @@ def test_create_pool(self): self.assertEqual(pool.slots, 5) self.assertEqual(pool.description, '') with create_session() as session: - self.assertEqual(session.query(models.Pool).count(), 3) + self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT + 1) def test_create_pool_existing(self): pool = pool_api.create_pool(name=self.pools[0].pool, @@ -85,7 +87,7 @@ def test_create_pool_existing(self): self.assertEqual(pool.slots, 5) self.assertEqual(pool.description, '') with create_session() as session: - self.assertEqual(session.query(models.Pool).count(), 2) + self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT) def test_create_pool_bad_name(self): for name in ('', ' '): @@ -105,10 +107,10 @@ def test_create_pool_bad_slots(self): description='') def test_delete_pool(self): - pool = pool_api.delete_pool(name=self.pools[0].pool) - self.assertEqual(pool.pool, self.pools[0].pool) + pool = pool_api.delete_pool(name=self.pools[-1].pool) + self.assertEqual(pool.pool, self.pools[-1].pool) with create_session() as session: - self.assertEqual(session.query(models.Pool).count(), 1) + self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT - 1) def test_delete_pool_non_existing(self): self.assertRaisesRegex(pool_api.PoolNotFound, @@ -123,6 +125,11 @@ def test_delete_pool_bad_name(self): pool_api.delete_pool, name=name) + def test_delete_default_pool_not_allowed(self): + with self.assertRaisesRegex(AirflowBadRequest, + "^default_pool cannot be deleted$"): + pool_api.delete_pool(Pool.DEFAULT_POOL_NAME) + if __name__ == '__main__': unittest.main() diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 57033fbb55e1ea..476397d7b5f3f4 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -41,7 +41,7 @@ from tests.compat import Mock, patch from tests.executors.test_executor import TestExecutor from tests.test_utils.db import clear_db_pools, \ - clear_db_runs + clear_db_runs, set_default_pool_slots configuration.load_test_config() @@ -52,7 +52,7 @@ class BackfillJobTest(unittest.TestCase): - def _get_dummy_dag(self, dag_id, pool=None, task_concurrency=None): + def _get_dummy_dag(self, dag_id, pool=Pool.DEFAULT_POOL_NAME, task_concurrency=None): dag = DAG( dag_id=dag_id, start_date=DEFAULT_DATE, @@ -397,18 +397,9 @@ def test_backfill_respect_dag_concurrency_limit(self, mock_log): self.assertGreater(times_dag_concurrency_limit_reached_in_debug, 0) @patch('airflow.jobs.backfill_job.BackfillJob.log') - @patch('airflow.jobs.backfill_job.conf.getint') - def test_backfill_with_no_pool_limit(self, mock_getint, mock_log): - non_pooled_backfill_task_slot_count = 2 - - def getint(section, key): - if section.lower() == 'core' and \ - 'non_pooled_backfill_task_slot_count' == key.lower(): - return non_pooled_backfill_task_slot_count - else: - return configuration.conf.getint(section, key) - - mock_getint.side_effect = getint + def test_backfill_respect_default_pool_limit(self, mock_log): + default_pool_slots = 2 + set_default_pool_slots(default_pool_slots) dag = self._get_dummy_dag('test_backfill_with_no_pool_limit') @@ -425,24 +416,24 
@@ def getint(section, key):
 
         self.assertTrue(0 < len(executor.history))
 
-        non_pooled_task_slot_count_reached_at_least_once = False
+        default_pool_task_slot_count_reached_at_least_once = False
 
         num_running_task_instances = 0
 
         # if no pool is specified, the number of tasks running in
         # parallel per backfill should be less than
-        # non_pooled_backfill_task_slot_count at any point of time.
+        # default_pool slots at any point of time.
         for running_task_instances in executor.history:
             self.assertLessEqual(
                 len(running_task_instances),
-                non_pooled_backfill_task_slot_count,
+                default_pool_slots,
             )
             num_running_task_instances += len(running_task_instances)
-            if len(running_task_instances) == non_pooled_backfill_task_slot_count:
-                non_pooled_task_slot_count_reached_at_least_once = True
+            if len(running_task_instances) == default_pool_slots:
+                default_pool_task_slot_count_reached_at_least_once = True
 
-        self.assertEqual(8, num_running_task_instances)
-        self.assertTrue(non_pooled_task_slot_count_reached_at_least_once)
+        self.assertEqual(8, num_running_task_instances)
+        self.assertTrue(default_pool_task_slot_count_reached_at_least_once)
 
         times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
             mock_log.debug,
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 6dcf697f380ce6..95f9f6a2ba03af 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -47,8 +47,7 @@
 from tests.core import TEST_DAG_FOLDER
 from tests.executors.test_executor import TestExecutor
 from tests.test_utils.db import clear_db_dags, clear_db_errors, clear_db_pools, \
-    clear_db_runs, clear_db_sla_miss
-from tests.test_utils.decorators import mock_conf_get
+    clear_db_runs, clear_db_sla_miss, set_default_pool_slots
 
 configuration.load_test_config()
 
@@ -354,9 +353,10 @@ def test_find_executable_task_instances_pool(self):
         self.assertIn(tis[1].key, res_keys)
         self.assertIn(tis[3].key, res_keys)
 
-    @mock_conf_get('core', 'non_pooled_task_slot_count', 1)
-    def test_find_executable_task_instances_in_non_pool(self):
-        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_in_non_pool'
+    def test_find_executable_task_instances_in_default_pool(self):
+        set_default_pool_slots(1)
+
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_in_default_pool'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         t1 = DummyOperator(dag=dag, task_id='dummy1')
         t2 = DummyOperator(dag=dag, task_id='dummy2')
@@ -366,18 +366,18 @@ def test_find_executable_task_instances_in_non_pool(self):
 
         scheduler = SchedulerJob(executor=executor)
         dr1 = scheduler.create_dag_run(dag)
         dr2 = scheduler.create_dag_run(dag)
-        session = settings.Session()
-
         ti1 = TI(task=t1, execution_date=dr1.execution_date)
         ti2 = TI(task=t2, execution_date=dr2.execution_date)
         ti1.state = State.SCHEDULED
         ti2.state = State.SCHEDULED
+
+        session = settings.Session()
         session.merge(ti1)
         session.merge(ti2)
         session.commit()
 
-        # Two tasks w/o pool up for execution and our non_pool size is 1
+        # Two tasks w/o pool up for execution and our default pool size is 1
         res = scheduler._find_executable_task_instances(
             dagbag,
             states=(State.SCHEDULED,),
@@ -385,7 +385,6 @@ def test_find_executable_task_instances_in_non_pool(self):
         self.assertEqual(1, len(res))
 
         ti2.state = State.RUNNING
-        ti2.pool = Pool.default_pool_name
         session.merge(ti2)
         session.commit()
 
diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py
index cff4de6fd98442..39c0f3c74f0ce9 100644
--- a/tests/models/test_pool.py
+++ b/tests/models/test_pool.py
@@ 
-26,14 +26,17 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.state import State -from tests.test_utils.db import clear_db_pools, clear_db_runs -from tests.test_utils.decorators import mock_conf_get +from tests.test_utils.db import clear_db_pools, clear_db_runs, set_default_pool_slots DEFAULT_DATE = timezone.datetime(2016, 1, 1) class PoolTest(unittest.TestCase): + def setUp(self): + clear_db_runs() + clear_db_pools() + def tearDown(self): clear_db_runs() clear_db_pools() @@ -59,8 +62,10 @@ def test_open_slots(self): self.assertEqual(3, pool.open_slots()) - @mock_conf_get('core', 'non_pooled_task_slot_count', 5) def test_default_pool_open_slots(self): + set_default_pool_slots(5) + self.assertEqual(5, Pool.get_default_pool().open_slots()) + dag = DAG( dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE, ) @@ -70,8 +75,6 @@ def test_default_pool_open_slots(self): ti2 = TI(task=t2, execution_date=DEFAULT_DATE) ti1.state = State.RUNNING ti2.state = State.QUEUED - ti1.pool = Pool.default_pool_name - ti2.pool = Pool.default_pool_name session = settings.Session session.add(ti1) @@ -79,4 +82,4 @@ def test_default_pool_open_slots(self): session.commit() session.close() - self.assertEqual(3, Pool.default_pool_open_slots()) + self.assertEqual(3, Pool.get_default_pool().open_slots()) diff --git a/tests/test_impersonation.py b/tests/test_impersonation.py index e4a9d90e7a486b..b06c5df035eef6 100644 --- a/tests/test_impersonation.py +++ b/tests/test_impersonation.py @@ -24,6 +24,7 @@ import logging from airflow import jobs, models +from airflow.utils.db import add_default_pool_if_not_exists from airflow.utils.state import State from airflow.utils.timezone import datetime @@ -45,6 +46,7 @@ class ImpersonationTest(unittest.TestCase): def setUp(self): + add_default_pool_if_not_exists() self.dagbag = models.DagBag( dag_folder=TEST_DAG_FOLDER, include_examples=False, diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py index 998e7cc0f8067f..c0f4b927c968d5 100644 --- a/tests/test_utils/db.py +++ b/tests/test_utils/db.py @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. 
from airflow.models import DagModel, DagRun, errors, Pool, SlaMiss, TaskInstance +from airflow.utils.db import add_default_pool_if_not_exists from airflow.utils.db import create_session @@ -44,3 +45,10 @@ def clear_db_errors(): def clear_db_pools(): with create_session() as session: session.query(Pool).delete() + add_default_pool_if_not_exists(session) + + +def set_default_pool_slots(slots): + with create_session() as session: + default_pool = Pool.get_default_pool(session) + default_pool.slots = slots diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py index ee39324c9279a2..0d32cfd872f439 100644 --- a/tests/www/api/experimental/test_endpoints.py +++ b/tests/www/api/experimental/test_endpoints.py @@ -30,6 +30,7 @@ from airflow.settings import Session from airflow.utils.timezone import datetime, utcnow, parse as parse_datetime from airflow.www import app as application +from tests.test_utils.db import clear_db_pools class TestBase(unittest.TestCase): @@ -288,19 +289,18 @@ def test_dagrun_status(self): class TestPoolApiExperimental(TestBase): + USER_POOL_COUNT = 2 + TOTAL_POOL_COUNT = USER_POOL_COUNT + 1 # including default_pool + @classmethod def setUpClass(cls): super(TestPoolApiExperimental, cls).setUpClass() - session = Session() - session.query(Pool).delete() - session.commit() - session.close() def setUp(self): super().setUp() - - self.pools = [] - for i in range(2): + clear_db_pools() + self.pools = [Pool.get_default_pool()] + for i in range(self.USER_POOL_COUNT): name = 'experimental_%s' % (i + 1) pool = Pool( pool=name, @@ -310,12 +310,9 @@ def setUp(self): self.session.add(pool) self.pools.append(pool) self.session.commit() - self.pool = self.pools[0] + self.pool = self.pools[-1] def tearDown(self): - self.session.query(Pool).delete() - self.session.commit() - self.session.close() super().tearDown() def _get_pool_count(self): @@ -341,7 +338,7 @@ def test_get_pools(self): response = self.client.get('/api/experimental/pools') self.assertEqual(response.status_code, 200) pools = json.loads(response.data.decode('utf-8')) - self.assertEqual(len(pools), 2) + self.assertEqual(len(pools), self.TOTAL_POOL_COUNT) for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])): self.assertDictEqual(pool, self.pools[i].to_json()) @@ -360,7 +357,7 @@ def test_create_pool(self): self.assertEqual(pool['pool'], 'foo') self.assertEqual(pool['slots'], 1) self.assertEqual(pool['description'], '') - self.assertEqual(self._get_pool_count(), 3) + self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT + 1) def test_create_pool_with_bad_name(self): for name in ('', ' '): @@ -378,7 +375,7 @@ def test_create_pool_with_bad_name(self): json.loads(response.data.decode('utf-8'))['error'], "Pool name shouldn't be empty", ) - self.assertEqual(self._get_pool_count(), 2) + self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT) def test_delete_pool(self): response = self.client.delete( @@ -387,7 +384,7 @@ def test_delete_pool(self): self.assertEqual(response.status_code, 200) self.assertEqual(json.loads(response.data.decode('utf-8')), self.pool.to_json()) - self.assertEqual(self._get_pool_count(), 1) + self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT - 1) def test_delete_pool_non_existing(self): response = self.client.delete( @@ -397,6 +394,15 @@ def test_delete_pool_non_existing(self): self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "Pool 'foo' doesn't exist") + def test_delete_default_pool(self): + 
clear_db_pools() + response = self.client.delete( + '/api/experimental/pools/default_pool', + ) + self.assertEqual(response.status_code, 400) + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], + "default_pool cannot be deleted") + if __name__ == '__main__': unittest.main()
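
For readers of the UPDATING.md entry above, here is a concrete way to resize `default_pool` after upgrading. This is an illustrative sketch, not part of the patch; it relies only on behavior exercised in this diff: the experimental `create_pool` updates the slot count of an existing pool (see `test_create_pool_existing`), and `delete_pool` now refuses `default_pool`.

    from airflow.api.common.experimental import pool as pool_api

    # create_pool upserts: for an existing pool, including default_pool,
    # it simply updates the slot count, so this grows default_pool from
    # the initial 128 slots to 256.
    pool_api.create_pool(name='default_pool', slots=256, description='Default pool')

    # The guard added to delete_pool above makes removal impossible:
    # pool_api.delete_pool('default_pool')  # raises AirflowBadRequest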
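Likewise, the new `Pool.get_default_pool()` accessor and the `set_default_pool_slots` helper introduced above make the slot accounting easy to probe. A minimal sketch, assuming an initialized Airflow metadata DB; note that `set_default_pool_slots` lives in `tests/test_utils/db.py` and is test-only:

    from airflow.models.pool import Pool
    from tests.test_utils.db import set_default_pool_slots

    # Shrink default_pool to 5 slots directly in the metadata DB.
    set_default_pool_slots(5)

    # Every task instance now carries a non-null pool (default_pool unless the
    # operator sets one), so open_slots() drops as default-pool task instances
    # enter the RUNNING or QUEUED state; with nothing running it reports 5.
    print(Pool.get_default_pool().open_slots())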