diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index c3dbfb895a47d..ad415b4ada842 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1125,7 +1125,10 @@ def _critical_section_execute_task_instances(self, session: Session) -> int: :type session: sqlalchemy.orm.Session :return: Number of task instance with state changed. """ - max_tis = min(self.max_tis_per_query, self.executor.slots_available) + if self.max_tis_per_query == 0: + max_tis = self.executor.slots_available + else: + max_tis = min(self.max_tis_per_query, self.executor.slots_available) queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) self._enqueue_task_instances_with_queued_state(queued_tis) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index d0d565d5599df..c5340e801761a 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -1681,6 +1681,56 @@ def test_execute_task_instances_limit(self): ti.refresh_from_db() assert State.QUEUED == ti.state + def test_execute_task_instances_unlimited(self): + """Test that max_tis_per_query=0 is unlimited""" + + dag_id = 'SchedulerJobTest.test_execute_task_instances_unlimited' + task_id_1 = 'dummy_task' + task_id_2 = 'dummy_task_2' + + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=1024) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + task2 = DummyOperator(dag=dag, task_id=task_id_2) + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + scheduler = SchedulerJob(subdir=os.devnull) + session = settings.Session() + + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + date = dag.start_date + tis = [] + for _ in range(0, 20): + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + ti1 = TaskInstance(task1, dr.execution_date) + ti2 = TaskInstance(task2, dr.execution_date) + tis.append(ti1) + tis.append(ti2) + ti1.refresh_from_db() + ti2.refresh_from_db() + ti1.state = State.SCHEDULED + ti2.state = State.SCHEDULED + session.merge(ti1) + session.merge(ti2) + session.flush() + scheduler.max_tis_per_query = 0 + scheduler.executor = MagicMock(slots_available=36) + + res = scheduler._critical_section_execute_task_instances(session) + # 20 dag runs * 2 tasks each = 40, but limited by number of slots available + self.assertEqual(36, res) + session.rollback() + def test_change_state_for_tis_without_dagrun(self): dag1 = DAG(dag_id='test_change_state_for_tis_without_dagrun', start_date=DEFAULT_DATE)