Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-4420] Backfill respects task_concurrency #5221

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,10 @@ class NoAvailablePoolSlot(AirflowException):


class DagConcurrencyLimitReached(AirflowException):
    """Raise when DAG concurrency limit is reached.

    Signals that the number of queued/running task instances of a DAG has
    reached the DAG's ``concurrency`` setting, so no further task instance
    should be scheduled for now.
    """
    pass


class TaskConcurrencyLimitReached(AirflowException):
    """Raise when task concurrency limit is reached.

    Signals that the number of queued/running instances of a single task has
    reached that task's ``task_concurrency`` setting.
    """
34 changes: 27 additions & 7 deletions airflow/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@

from airflow import configuration as conf
from airflow import executors, models, settings
from airflow.exceptions import (AirflowException, DagConcurrencyLimitReached,
NoAvailablePoolSlot, PoolNotFound)
from airflow.exceptions import (
AirflowException,
DagConcurrencyLimitReached,
NoAvailablePoolSlot,
PoolNotFound,
TaskConcurrencyLimitReached,
)
from airflow.models import DAG, DagPickle, DagRun, SlaMiss, errors
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
Expand Down Expand Up @@ -1800,6 +1805,7 @@ class BackfillJob(BaseJob):
"""
ID_PREFIX = 'backfill_'
ID_FORMAT_PREFIX = ID_PREFIX + '{0}'
STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)

__mapper_args__ = {
'polymorphic_identity': 'BackfillJob'
Expand Down Expand Up @@ -2315,18 +2321,32 @@ def _per_task_process(task, key, ti, session=None):
"non_pooled_backfill_task_slot_count.")
non_pool_slots -= 1

num_running_tasks = DAG.get_num_task_instances(
num_running_task_instances_in_dag = DAG.get_num_task_instances(
self.dag_id,
states=(State.QUEUED, State.RUNNING))
states=self.STATES_COUNT_AS_RUNNING,
)

if num_running_tasks >= self.dag.concurrency:
if num_running_task_instances_in_dag >= self.dag.concurrency:
raise DagConcurrencyLimitReached(
"Not scheduling since concurrency limit "
"Not scheduling since DAG concurrency limit "
"is reached."
)

if task.task_concurrency:
num_running_task_instances_in_task = DAG.get_num_task_instances(
dag_id=self.dag_id,
task_ids=[task.task_id],
states=self.STATES_COUNT_AS_RUNNING,
)

if num_running_task_instances_in_task >= task.task_concurrency:
raise TaskConcurrencyLimitReached(
"Not scheduling since Task concurrency limit "
"is reached."
)

_per_task_process(task, key, ti)
except (NoAvailablePoolSlot, DagConcurrencyLimitReached) as e:
except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e:
self.log.debug(e)

# execute the tasks in the queue
Expand Down
92 changes: 82 additions & 10 deletions tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
from airflow import AirflowException, models, settings
from airflow import configuration
from airflow.bin import cli
from airflow.exceptions import DagConcurrencyLimitReached, NoAvailablePoolSlot
from airflow.exceptions import DagConcurrencyLimitReached, NoAvailablePoolSlot, \
TaskConcurrencyLimitReached
from airflow.executors import BaseExecutor, SequentialExecutor
from airflow.jobs import BackfillJob, BaseJob, LocalTaskJob, SchedulerJob
from airflow.models import DAG, DagBag, DagModel, DagRun, Pool, SlaMiss, \
Expand Down Expand Up @@ -123,7 +124,7 @@ def abort():

class BackfillJobTest(unittest.TestCase):

def _get_dummy_dag(self, dag_id, pool=None):
def _get_dummy_dag(self, dag_id, pool=None, task_concurrency=None):
dag = DAG(
dag_id=dag_id,
start_date=DEFAULT_DATE,
Expand All @@ -133,6 +134,7 @@ def _get_dummy_dag(self, dag_id, pool=None):
DummyOperator(
task_id='op',
pool=pool,
task_concurrency=task_concurrency,
dag=dag)

dag.clear()
Expand Down Expand Up @@ -364,9 +366,61 @@ def test_backfill_conf(self):
self.assertEqual(conf, dr[0].conf)

@patch('airflow.jobs.LoggingMixin.log')
def test_backfill_respect_task_concurrency_limit(self, mock_log):
    """Check that a backfill honours the ``task_concurrency`` of a task.

    Runs a backfill over ``DEFAULT_DATE`` .. ``DEFAULT_DATE + 7 days`` for a
    dummy DAG whose single task has ``task_concurrency=2`` and verifies that:

    * no entry of ``executor.history`` holds more than ``task_concurrency``
      task instances, and the limit is actually hit at least once;
    * 8 task instances are seen in total across the history;
    * only ``TaskConcurrencyLimitReached`` (and neither
      ``DagConcurrencyLimitReached`` nor ``NoAvailablePoolSlot``) was passed
      to the mocked ``log.debug``.

    :param mock_log: mock injected by ``@patch`` in place of
        ``LoggingMixin.log``; used to inspect the ``debug`` calls.
    """
    task_concurrency = 2
    dag = self._get_dummy_dag(
        'test_backfill_respect_task_concurrency_limit',
        task_concurrency=task_concurrency,
    )

    executor = TestExecutor()

    job = BackfillJob(
        dag=dag,
        executor=executor,
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE + datetime.timedelta(days=7),
    )

    job.run()

    # The executor must have recorded at least one batch of work.
    self.assertGreater(len(executor.history), 0)

    task_concurrency_limit_reached_at_least_once = False

    num_running_task_instances = 0
    for running_task_instances in executor.history:
        # No batch may exceed the per-task concurrency limit.
        self.assertLessEqual(len(running_task_instances), task_concurrency)
        num_running_task_instances += len(running_task_instances)
        if len(running_task_instances) == task_concurrency:
            task_concurrency_limit_reached_at_least_once = True

    self.assertEqual(8, num_running_task_instances)
    self.assertTrue(task_concurrency_limit_reached_at_least_once)

    times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
        mock_log.debug,
        DagConcurrencyLimitReached,
    )

    times_pool_limit_reached_in_debug = self._times_called_with(
        mock_log.debug,
        NoAvailablePoolSlot,
    )

    times_task_concurrency_limit_reached_in_debug = self._times_called_with(
        mock_log.debug,
        TaskConcurrencyLimitReached,
    )

    # Only the task-level limit should have throttled this backfill.
    self.assertEqual(0, times_pool_limit_reached_in_debug)
    self.assertEqual(0, times_dag_concurrency_limit_reached_in_debug)
    self.assertGreater(times_task_concurrency_limit_reached_in_debug, 0)

@patch('airflow.jobs.LoggingMixin.log')
def test_backfill_respect_dag_concurrency_limit(self, mock_log):

dag = self._get_dummy_dag('test_backfill_respect_dag_concurrency_limit')
dag.concurrency = 2

executor = TestExecutor()
Expand Down Expand Up @@ -395,7 +449,7 @@ def test_backfill_respect_concurrency_limit(self, mock_log):
self.assertEquals(8, num_running_task_instances)
self.assertTrue(concurrency_limit_reached_at_least_once)

times_concurrency_limit_reached_in_debug = self._times_called_with(
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
DagConcurrencyLimitReached,
)
Expand All @@ -405,8 +459,14 @@ def test_backfill_respect_concurrency_limit(self, mock_log):
NoAvailablePoolSlot,
)

times_task_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
TaskConcurrencyLimitReached,
)

self.assertEquals(0, times_pool_limit_reached_in_debug)
self.assertGreater(times_concurrency_limit_reached_in_debug, 0)
self.assertEquals(0, times_task_concurrency_limit_reached_in_debug)
self.assertGreater(times_dag_concurrency_limit_reached_in_debug, 0)

@patch('airflow.jobs.LoggingMixin.log')
@patch('airflow.jobs.conf.getint')
Expand Down Expand Up @@ -456,7 +516,7 @@ def getint(section, key):
self.assertEquals(8, num_running_task_instances)
self.assertTrue(non_pooled_task_slot_count_reached_at_least_once)

times_concurrency_limit_reached_in_debug = self._times_called_with(
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
DagConcurrencyLimitReached,
)
Expand All @@ -466,7 +526,13 @@ def getint(section, key):
NoAvailablePoolSlot,
)

self.assertEquals(0, times_concurrency_limit_reached_in_debug)
times_task_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
TaskConcurrencyLimitReached,
)

self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug)
self.assertEquals(0, times_task_concurrency_limit_reached_in_debug)
self.assertGreater(times_pool_limit_reached_in_debug, 0)

def test_backfill_pool_not_found(self):
Expand Down Expand Up @@ -533,7 +599,7 @@ def test_backfill_respect_pool_limit(self, mock_log):
self.assertEquals(8, num_running_task_instances)
self.assertTrue(pool_was_full_at_least_once)

times_concurrency_limit_reached_in_debug = self._times_called_with(
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
DagConcurrencyLimitReached,
)
Expand All @@ -543,7 +609,13 @@ def test_backfill_respect_pool_limit(self, mock_log):
NoAvailablePoolSlot,
)

self.assertEquals(0, times_concurrency_limit_reached_in_debug)
times_task_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
TaskConcurrencyLimitReached,
)

self.assertEquals(0, times_task_concurrency_limit_reached_in_debug)
self.assertEquals(0, times_dag_concurrency_limit_reached_in_debug)
self.assertGreater(times_pool_limit_reached_in_debug, 0)

def test_backfill_run_rescheduled(self):
Expand Down