[AIRFLOW-4420] Backfill respects task_concurrency (apache#5221)
Ensure that backfill respects task_concurrency:
the number of concurrently running instances of a task
across DAG runs should not exceed its task_concurrency.
milton0825 authored and wayne.morris committed Jul 29, 2019
1 parent 7e5da60 commit 49cc632
Showing 3 changed files with 115 additions and 18 deletions.
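
As a quick illustration of the behaviour this commit enforces, here is a minimal sketch (the DAG id, dates, and the use of DummyOperator are illustrative assumptions, not part of the change): when this DAG is backfilled over several days, at most two instances of the 'op' task may be running or queued at once across all backfilled DAG runs.

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

# Hypothetical example DAG; task_concurrency is the per-task cap that
# backfill now honours across DAG runs.
dag = DAG(
    dag_id='example_backfill_task_concurrency',
    start_date=datetime(2019, 1, 1),
    schedule_interval=timedelta(days=1),
)

DummyOperator(
    task_id='op',
    task_concurrency=2,  # at most 2 'op' instances running/queued at a time
    dag=dag,
)

Backfilling this DAG over, say, a week would previously launch as many 'op' instances as the pool, DAG concurrency, and slot limits allowed; with this change the backfill job never exceeds two at a time.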
7 changes: 6 additions & 1 deletion airflow/exceptions.py
@@ -119,5 +119,10 @@ class NoAvailablePoolSlot(AirflowException):


class DagConcurrencyLimitReached(AirflowException):
"""Raise when concurrency limit is reached"""
"""Raise when DAG concurrency limit is reached"""
pass


class TaskConcurrencyLimitReached(AirflowException):
"""Raise when task concurrency limit is reached"""
pass
34 changes: 27 additions & 7 deletions airflow/jobs.py
@@ -37,8 +37,13 @@

from airflow import configuration as conf
from airflow import executors, models, settings
from airflow.exceptions import (AirflowException, DagConcurrencyLimitReached,
NoAvailablePoolSlot, PoolNotFound)
from airflow.exceptions import (
AirflowException,
DagConcurrencyLimitReached,
NoAvailablePoolSlot,
PoolNotFound,
TaskConcurrencyLimitReached,
)
from airflow.models import DAG, DagPickle, DagRun, SlaMiss, errors
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
@@ -1800,6 +1805,7 @@ class BackfillJob(BaseJob):
"""
ID_PREFIX = 'backfill_'
ID_FORMAT_PREFIX = ID_PREFIX + '{0}'
STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)

__mapper_args__ = {
'polymorphic_identity': 'BackfillJob'
@@ -2315,18 +2321,32 @@ def _per_task_process(task, key, ti, session=None):
"non_pooled_backfill_task_slot_count.")
non_pool_slots -= 1

num_running_tasks = DAG.get_num_task_instances(
num_running_task_instances_in_dag = DAG.get_num_task_instances(
self.dag_id,
states=(State.QUEUED, State.RUNNING))
states=self.STATES_COUNT_AS_RUNNING,
)

if num_running_tasks >= self.dag.concurrency:
if num_running_task_instances_in_dag >= self.dag.concurrency:
raise DagConcurrencyLimitReached(
"Not scheduling since concurrency limit "
"Not scheduling since DAG concurrency limit "
"is reached."
)

if task.task_concurrency:
num_running_task_instances_in_task = DAG.get_num_task_instances(
dag_id=self.dag_id,
task_ids=[task.task_id],
states=self.STATES_COUNT_AS_RUNNING,
)

if num_running_task_instances_in_task >= task.task_concurrency:
raise TaskConcurrencyLimitReached(
"Not scheduling since Task concurrency limit "
"is reached."
)

_per_task_process(task, key, ti)
except (NoAvailablePoolSlot, DagConcurrencyLimitReached) as e:
except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e:
self.log.debug(e)

# execute the tasks in the queue
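
Read on its own, the per-task guard added above boils down to the following standalone sketch (a simplified restatement of the hunk for readability, not the actual BackfillJob code; the function name is made up):

from airflow.exceptions import TaskConcurrencyLimitReached
from airflow.models import DAG
from airflow.utils.state import State

STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)

def check_task_concurrency_limit(dag_id, task):
    # No per-task limit configured: nothing to enforce.
    if not task.task_concurrency:
        return
    # Count running/queued instances of this task across all DAG runs.
    running = DAG.get_num_task_instances(
        dag_id=dag_id,
        task_ids=[task.task_id],
        states=STATES_COUNT_AS_RUNNING,
    )
    if running >= task.task_concurrency:
        raise TaskConcurrencyLimitReached(
            "Not scheduling since Task concurrency limit is reached."
        )

The exception is then caught alongside NoAvailablePoolSlot and DagConcurrencyLimitReached, so a task instance that hits the limit is simply skipped in this pass and retried on a later iteration of the backfill loop.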
92 changes: 82 additions & 10 deletions tests/test_jobs.py
@@ -37,7 +37,8 @@
from airflow import AirflowException, models, settings
from airflow import configuration
from airflow.bin import cli
from airflow.exceptions import DagConcurrencyLimitReached, NoAvailablePoolSlot
from airflow.exceptions import DagConcurrencyLimitReached, NoAvailablePoolSlot, \
TaskConcurrencyLimitReached
from airflow.executors import BaseExecutor, SequentialExecutor
from airflow.jobs import BackfillJob, BaseJob, LocalTaskJob, SchedulerJob
from airflow.models import DAG, DagBag, DagModel, DagRun, Pool, SlaMiss, \
@@ -123,7 +124,7 @@ def abort():

class BackfillJobTest(unittest.TestCase):

def _get_dummy_dag(self, dag_id, pool=None):
def _get_dummy_dag(self, dag_id, pool=None, task_concurrency=None):
dag = DAG(
dag_id=dag_id,
start_date=DEFAULT_DATE,
@@ -133,6 +134,7 @@ def _get_dummy_dag(self, dag_id, pool=None):
DummyOperator(
task_id='op',
pool=pool,
task_concurrency=task_concurrency,
dag=dag)

dag.clear()
@@ -364,9 +366,61 @@ def test_backfill_conf(self):
self.assertEqual(conf, dr[0].conf)

@patch('airflow.jobs.LoggingMixin.log')
def test_backfill_respect_concurrency_limit(self, mock_log):
def test_backfill_respect_task_concurrency_limit(self, mock_log):
task_concurrency = 2
dag = self._get_dummy_dag(
'test_backfill_respect_task_concurrency_limit',
task_concurrency=task_concurrency,
)

executor = TestExecutor()

job = BackfillJob(
dag=dag,
executor=executor,
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=7),
)

job.run()

self.assertTrue(0 < len(executor.history))

task_concurrency_limit_reached_at_least_once = False

num_running_task_instances = 0
for running_task_instances in executor.history:
self.assertLessEqual(len(running_task_instances), task_concurrency)
num_running_task_instances += len(running_task_instances)
if len(running_task_instances) == task_concurrency:
task_concurrency_limit_reached_at_least_once = True

self.assertEquals(8, num_running_task_instances)
self.assertTrue(task_concurrency_limit_reached_at_least_once)

dag = self._get_dummy_dag('test_backfill_respect_concurrency_limit')
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
DagConcurrencyLimitReached,
)

times_pool_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
NoAvailablePoolSlot,
)

times_task_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
TaskConcurrencyLimitReached,
)

self.assertEquals(0, times_pool_limit_reached_in_debug)
self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug)
self.assertGreater(times_task_concurrency_limit_reached_in_debug, 0)

@patch('airflow.jobs.LoggingMixin.log')
def test_backfill_respect_dag_concurrency_limit(self, mock_log):

dag = self._get_dummy_dag('test_backfill_respect_dag_concurrency_limit')
dag.concurrency = 2

executor = TestExecutor()
@@ -395,7 +449,7 @@ def test_backfill_respect_concurrency_limit(self, mock_log):
self.assertEquals(8, num_running_task_instances)
self.assertTrue(concurrency_limit_reached_at_least_once)

times_concurrency_limit_reached_in_debug = self._times_called_with(
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
DagConcurrencyLimitReached,
)
@@ -405,8 +459,14 @@ def test_backfill_respect_concurrency_limit(self, mock_log):
NoAvailablePoolSlot,
)

times_task_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
TaskConcurrencyLimitReached,
)

self.assertEquals(0, times_pool_limit_reached_in_debug)
self.assertGreater(times_concurrency_limit_reached_in_debug, 0)
self.assertEquals(0, times_task_concurrency_limit_reached_in_debug)
self.assertGreater(times_dag_concurrency_limit_reached_in_debug, 0)

@patch('airflow.jobs.LoggingMixin.log')
@patch('airflow.jobs.conf.getint')
@@ -456,7 +516,7 @@ def getint(section, key):
self.assertEquals(8, num_running_task_instances)
self.assertTrue(non_pooled_task_slot_count_reached_at_least_once)

times_concurrency_limit_reached_in_debug = self._times_called_with(
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
DagConcurrencyLimitReached,
)
@@ -466,7 +526,13 @@ def getint(section, key):
NoAvailablePoolSlot,
)

self.assertEquals(0, times_concurrency_limit_reached_in_debug)
times_task_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
TaskConcurrencyLimitReached,
)

self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug)
self.assertEquals(0, times_task_concurrency_limit_reached_in_debug)
self.assertGreater(times_pool_limit_reached_in_debug, 0)

def test_backfill_pool_not_found(self):
@@ -533,7 +599,7 @@ def test_backfill_respect_pool_limit(self, mock_log):
self.assertEquals(8, num_running_task_instances)
self.assertTrue(pool_was_full_at_least_once)

times_concurrency_limit_reached_in_debug = self._times_called_with(
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
DagConcurrencyLimitReached,
)
@@ -543,7 +609,13 @@ def test_backfill_respect_pool_limit(self, mock_log):
NoAvailablePoolSlot,
)

self.assertEquals(0, times_concurrency_limit_reached_in_debug)
times_task_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
TaskConcurrencyLimitReached,
)

self.assertEquals(0, times_task_concurrency_limit_reached_in_debug)
self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug)
self.assertGreater(times_pool_limit_reached_in_debug, 0)

def test_backfill_run_rescheduled(self):
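
For reference, the _times_called_with helper these assertions rely on is defined earlier in BackfillJobTest and is not part of this diff; a plausible reconstruction (an assumption, shown only to make the assertions readable) counts how many calls to the mocked logger's debug method received an instance of the given exception class:

def _times_called_with(self, method, class_):
    # Count calls whose first positional argument is an instance of class_.
    count = 0
    for call in method.call_args_list:
        if isinstance(call[0][0], class_):
            count += 1
    return count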
