diff --git a/airflow/exceptions.py b/airflow/exceptions.py index b1ca2471df54f..5f0b6d65234b3 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -119,5 +119,10 @@ class NoAvailablePoolSlot(AirflowException): class DagConcurrencyLimitReached(AirflowException): - """Raise when concurrency limit is reached""" + """Raise when DAG concurrency limit is reached""" + pass + + +class TaskConcurrencyLimitReached(AirflowException): + """Raise when task concurrency limit is reached""" pass diff --git a/airflow/jobs.py b/airflow/jobs.py index b0aa1d96e69d8..ce42a5f8b8310 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -37,8 +37,13 @@ from airflow import configuration as conf from airflow import executors, models, settings -from airflow.exceptions import (AirflowException, DagConcurrencyLimitReached, - NoAvailablePoolSlot, PoolNotFound) +from airflow.exceptions import ( + AirflowException, + DagConcurrencyLimitReached, + NoAvailablePoolSlot, + PoolNotFound, + TaskConcurrencyLimitReached, +) from airflow.models import DAG, DagPickle, DagRun, SlaMiss, errors from airflow.stats import Stats from airflow.task.task_runner import get_task_runner @@ -1800,6 +1805,7 @@ class BackfillJob(BaseJob): """ ID_PREFIX = 'backfill_' ID_FORMAT_PREFIX = ID_PREFIX + '{0}' + STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED) __mapper_args__ = { 'polymorphic_identity': 'BackfillJob' @@ -2315,18 +2321,32 @@ def _per_task_process(task, key, ti, session=None): "non_pooled_backfill_task_slot_count.") non_pool_slots -= 1 - num_running_tasks = DAG.get_num_task_instances( + num_running_task_instances_in_dag = DAG.get_num_task_instances( self.dag_id, - states=(State.QUEUED, State.RUNNING)) + states=self.STATES_COUNT_AS_RUNNING, + ) - if num_running_tasks >= self.dag.concurrency: + if num_running_task_instances_in_dag >= self.dag.concurrency: raise DagConcurrencyLimitReached( - "Not scheduling since concurrency limit " + "Not scheduling since DAG concurrency limit " "is reached." ) + if task.task_concurrency: + num_running_task_instances_in_task = DAG.get_num_task_instances( + dag_id=self.dag_id, + task_ids=[task.task_id], + states=self.STATES_COUNT_AS_RUNNING, + ) + + if num_running_task_instances_in_task >= task.task_concurrency: + raise TaskConcurrencyLimitReached( + "Not scheduling since Task concurrency limit " + "is reached." + ) + _per_task_process(task, key, ti) - except (NoAvailablePoolSlot, DagConcurrencyLimitReached) as e: + except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e: self.log.debug(e) # execute the tasks in the queue diff --git a/tests/test_jobs.py b/tests/test_jobs.py index 7740d07ee4932..beb16bf3715d1 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -37,7 +37,8 @@ from airflow import AirflowException, models, settings from airflow import configuration from airflow.bin import cli -from airflow.exceptions import DagConcurrencyLimitReached, NoAvailablePoolSlot +from airflow.exceptions import DagConcurrencyLimitReached, NoAvailablePoolSlot, \ + TaskConcurrencyLimitReached from airflow.executors import BaseExecutor, SequentialExecutor from airflow.jobs import BackfillJob, BaseJob, LocalTaskJob, SchedulerJob from airflow.models import DAG, DagBag, DagModel, DagRun, Pool, SlaMiss, \ @@ -123,7 +124,7 @@ def abort(): class BackfillJobTest(unittest.TestCase): - def _get_dummy_dag(self, dag_id, pool=None): + def _get_dummy_dag(self, dag_id, pool=None, task_concurrency=None): dag = DAG( dag_id=dag_id, start_date=DEFAULT_DATE, @@ -133,6 +134,7 @@ def _get_dummy_dag(self, dag_id, pool=None): DummyOperator( task_id='op', pool=pool, + task_concurrency=task_concurrency, dag=dag) dag.clear() @@ -364,9 +366,61 @@ def test_backfill_conf(self): self.assertEqual(conf, dr[0].conf) @patch('airflow.jobs.LoggingMixin.log') - def test_backfill_respect_concurrency_limit(self, mock_log): + def test_backfill_respect_task_concurrency_limit(self, mock_log): + task_concurrency = 2 + dag = self._get_dummy_dag( + 'test_backfill_respect_task_concurrency_limit', + task_concurrency=task_concurrency, + ) + + executor = TestExecutor() + + job = BackfillJob( + dag=dag, + executor=executor, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=7), + ) + + job.run() + + self.assertTrue(0 < len(executor.history)) + + task_concurrency_limit_reached_at_least_once = False + + num_running_task_instances = 0 + for running_task_instances in executor.history: + self.assertLessEqual(len(running_task_instances), task_concurrency) + num_running_task_instances += len(running_task_instances) + if len(running_task_instances) == task_concurrency: + task_concurrency_limit_reached_at_least_once = True + + self.assertEquals(8, num_running_task_instances) + self.assertTrue(task_concurrency_limit_reached_at_least_once) - dag = self._get_dummy_dag('test_backfill_respect_concurrency_limit') + times_dag_concurrency_limit_reached_in_debug = self._times_called_with( + mock_log.debug, + DagConcurrencyLimitReached, + ) + + times_pool_limit_reached_in_debug = self._times_called_with( + mock_log.debug, + NoAvailablePoolSlot, + ) + + times_task_concurrency_limit_reached_in_debug = self._times_called_with( + mock_log.debug, + TaskConcurrencyLimitReached, + ) + + self.assertEquals(0, times_pool_limit_reached_in_debug) + self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug) + self.assertGreater(times_task_concurrency_limit_reached_in_debug, 0) + + @patch('airflow.jobs.LoggingMixin.log') + def test_backfill_respect_dag_concurrency_limit(self, mock_log): + + dag = self._get_dummy_dag('test_backfill_respect_dag_concurrency_limit') dag.concurrency = 2 executor = TestExecutor() @@ -395,7 +449,7 @@ def test_backfill_respect_concurrency_limit(self, mock_log): self.assertEquals(8, num_running_task_instances) self.assertTrue(concurrency_limit_reached_at_least_once) - times_concurrency_limit_reached_in_debug = self._times_called_with( + times_dag_concurrency_limit_reached_in_debug = self._times_called_with( mock_log.debug, DagConcurrencyLimitReached, ) @@ -405,8 +459,14 @@ def test_backfill_respect_concurrency_limit(self, mock_log): NoAvailablePoolSlot, ) + times_task_concurrency_limit_reached_in_debug = self._times_called_with( + mock_log.debug, + TaskConcurrencyLimitReached, + ) + self.assertEquals(0, times_pool_limit_reached_in_debug) - self.assertGreater(times_concurrency_limit_reached_in_debug, 0) + self.assertEquals(0, times_task_concurrency_limit_reached_in_debug) + self.assertGreater(times_dag_concurrency_limit_reached_in_debug, 0) @patch('airflow.jobs.LoggingMixin.log') @patch('airflow.jobs.conf.getint') @@ -456,7 +516,7 @@ def getint(section, key): self.assertEquals(8, num_running_task_instances) self.assertTrue(non_pooled_task_slot_count_reached_at_least_once) - times_concurrency_limit_reached_in_debug = self._times_called_with( + times_dag_concurrency_limit_reached_in_debug = self._times_called_with( mock_log.debug, DagConcurrencyLimitReached, ) @@ -466,7 +526,13 @@ def getint(section, key): NoAvailablePoolSlot, ) - self.assertEquals(0, times_concurrency_limit_reached_in_debug) + times_task_concurrency_limit_reached_in_debug = self._times_called_with( + mock_log.debug, + TaskConcurrencyLimitReached, + ) + + self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug) + self.assertEquals(0, times_task_concurrency_limit_reached_in_debug) self.assertGreater(times_pool_limit_reached_in_debug, 0) def test_backfill_pool_not_found(self): @@ -533,7 +599,7 @@ def test_backfill_respect_pool_limit(self, mock_log): self.assertEquals(8, num_running_task_instances) self.assertTrue(pool_was_full_at_least_once) - times_concurrency_limit_reached_in_debug = self._times_called_with( + times_dag_concurrency_limit_reached_in_debug = self._times_called_with( mock_log.debug, DagConcurrencyLimitReached, ) @@ -543,7 +609,13 @@ def test_backfill_respect_pool_limit(self, mock_log): NoAvailablePoolSlot, ) - self.assertEquals(0, times_concurrency_limit_reached_in_debug) + times_task_concurrency_limit_reached_in_debug = self._times_called_with( + mock_log.debug, + TaskConcurrencyLimitReached, + ) + + self.assertEquals(0, times_task_concurrency_limit_reached_in_debug) + self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug) self.assertGreater(times_pool_limit_reached_in_debug, 0) def test_backfill_run_rescheduled(self):