diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 47ee70bbe2be1b..3b213abf66045f 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -49,6 +49,7 @@
 from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
+from airflow.settings import run_with_db_retries
 from airflow.stats import Stats
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.utils import timezone
@@ -1472,15 +1473,9 @@ def _do_scheduling(self, session) -> int:
 
         with prohibit_commit(session) as guard:
 
             if settings.USE_JOB_SCHEDULE:
-                query = DagModel.dags_needing_dagruns(session)
-                self._create_dag_runs(query.all(), session)
-
-                # commit the session - Release the write lock on DagModel table.
-                guard.commit()
-                # END: create dagruns
-
-            dag_runs = DagRun.next_dagruns_to_examine(session)
+                self._create_dagruns_for_dags(guard, session)
+            dag_runs = self._get_next_dagruns_to_examine(session)
 
             # Bulk fetch the currently active dag runs for the dags we are
             # examining, rather than making one query per DagRun
@@ -1560,6 +1555,46 @@ def _do_scheduling(self, session) -> int:
             guard.commit()
         return num_queued_tis
 
+    def _get_next_dagruns_to_examine(self, session):
+        """Get Next DagRuns to Examine with retries"""
+        for attempt in run_with_db_retries(logger=self.log):
+            with attempt:
+                try:
+                    self.log.debug(
+                        "Running SchedulerJob._get_dagmodels_and_create_dagruns with retries. "
+                        "Try %d of %d",
+                        attempt.retry_state.attempt_number,
+                        settings.MAX_DB_RETRIES,
+                    )
+                    dag_runs = DagRun.next_dagruns_to_examine(session)
+
+                except OperationalError:
+                    session.rollback()
+                    raise
+
+        return dag_runs
+
+    def _create_dagruns_for_dags(self, guard, session):
+        """Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError"""
+        for attempt in run_with_db_retries(logger=self.log):
+            with attempt:
+                try:
+                    self.log.debug(
+                        "Running SchedulerJob._create_dagruns_for_dags with retries. " "Try %d of %d",
+                        attempt.retry_state.attempt_number,
+                        settings.MAX_DB_RETRIES,
+                    )
+                    query = DagModel.dags_needing_dagruns(session)
+                    self._create_dag_runs(query.all(), session)
+
+                    # commit the session - Release the write lock on DagModel table.
+                    guard.commit()
+                    # END: create dagruns
+
+                except OperationalError:
+                    session.rollback()
+                    raise
+
     def _create_dag_runs(self, dag_models: Iterable[DagModel], session: Session) -> None:
         """
         Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control
@@ -1797,63 +1832,78 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
         self.log.info("Resetting orphaned tasks for active dag runs")
         timeout = conf.getint('scheduler', 'scheduler_health_check_threshold')
 
-        num_failed = (
-            session.query(SchedulerJob)
-            .filter(
-                SchedulerJob.state == State.RUNNING,
-                SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
-            )
-            .update({"state": State.FAILED})
-        )
-
-        if num_failed:
-            self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
-            Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)
-
-        resettable_states = [State.SCHEDULED, State.QUEUED, State.RUNNING]
-        query = (
-            session.query(TI)
-            .filter(TI.state.in_(resettable_states))
-            # outerjoin is because we didn't use to have queued_by_job
-            # set, so we need to pick up anything pre upgrade. This (and the
-            # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
-            # released.
-            .outerjoin(TI.queued_by_job)
-            .filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING))
-            .join(TI.dag_run)
-            .filter(
-                DagRun.run_type != DagRunType.BACKFILL_JOB,
-                # pylint: disable=comparison-with-callable
-                DagRun.state == State.RUNNING,
-            )
-            .options(load_only(TI.dag_id, TI.task_id, TI.execution_date))
-        )
-
-        # Lock these rows, so that another scheduler can't try and adopt these too
-        tis_to_reset_or_adopt = with_row_locks(
-            query, of=TI, session=session, **skip_locked(session=session)
-        ).all()
-        to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt)
-
-        reset_tis_message = []
-        for ti in to_reset:
-            reset_tis_message.append(repr(ti))
-            ti.state = State.NONE
-            ti.queued_by_job_id = None
-
-        for ti in set(tis_to_reset_or_adopt) - set(to_reset):
-            ti.queued_by_job_id = self.id
+        for attempt in run_with_db_retries(logger=self.log):
+            with attempt:
+                self.log.debug(
+                    "Running SchedulerJob.adopt_or_reset_orphaned_tasks with retries. Try %d of %d",
+                    attempt.retry_state.attempt_number,
+                    settings.MAX_DB_RETRIES,
+                )
+                self.log.debug("Calling SchedulerJob.adopt_or_reset_orphaned_tasks method")
+                try:
+                    num_failed = (
+                        session.query(SchedulerJob)
+                        .filter(
+                            SchedulerJob.state == State.RUNNING,
+                            SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
+                        )
+                        .update({"state": State.FAILED})
+                    )
 
-        Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset))
-        Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset))
+                    if num_failed:
+                        self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
+                        Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)
+
+                    resettable_states = [State.SCHEDULED, State.QUEUED, State.RUNNING]
+                    query = (
+                        session.query(TI)
+                        .filter(TI.state.in_(resettable_states))
+                        # outerjoin is because we didn't use to have queued_by_job
+                        # set, so we need to pick up anything pre upgrade. This (and the
+                        # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
+                        # released.
+                        .outerjoin(TI.queued_by_job)
+                        .filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING))
+                        .join(TI.dag_run)
+                        .filter(
+                            DagRun.run_type != DagRunType.BACKFILL_JOB,
+                            # pylint: disable=comparison-with-callable
+                            DagRun.state == State.RUNNING,
+                        )
+                        .options(load_only(TI.dag_id, TI.task_id, TI.execution_date))
+                    )
 
-        if to_reset:
-            task_instance_str = '\n\t'.join(reset_tis_message)
-            self.log.info(
-                "Reset the following %s orphaned TaskInstances:\n\t%s", len(to_reset), task_instance_str
-            )
+                    # Lock these rows, so that another scheduler can't try and adopt these too
+                    tis_to_reset_or_adopt = with_row_locks(
+                        query, of=TI, session=session, **skip_locked(session=session)
+                    ).all()
+                    to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt)
+
+                    reset_tis_message = []
+                    for ti in to_reset:
+                        reset_tis_message.append(repr(ti))
+                        ti.state = State.NONE
+                        ti.queued_by_job_id = None
+
+                    for ti in set(tis_to_reset_or_adopt) - set(to_reset):
+                        ti.queued_by_job_id = self.id
+
+                    Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset))
+                    Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset))
+
+                    if to_reset:
+                        task_instance_str = '\n\t'.join(reset_tis_message)
+                        self.log.info(
+                            "Reset the following %s orphaned TaskInstances:\n\t%s",
+                            len(to_reset),
+                            task_instance_str,
+                        )
+
+                    # Issue SQL/finish "Unit of Work", but let @provide_session
+                    # commit (or if passed a session, let caller decide when to commit
+                    session.flush()
+                except OperationalError:
+                    session.rollback()
+                    raise
 
-        # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
-        # decide when to commit
-        session.flush()
         return len(to_reset)
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index b493334cc20b31..4a96f4fdaf0af3 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -20,7 +20,6 @@
 import importlib
 import importlib.machinery
 import importlib.util
-import logging
 import os
 import sys
 import textwrap
@@ -30,7 +29,6 @@
 from datetime import datetime, timedelta
 from typing import Dict, List, NamedTuple, Optional
 
-import tenacity
 from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session
@@ -39,6 +37,7 @@
 from airflow import settings
 from airflow.configuration import conf
 from airflow.exceptions import AirflowClusterPolicyViolation, AirflowDagCycleException, SerializedDagNotFound
+from airflow.settings import run_with_db_retries
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.dag_cycle_tester import test_cycle
@@ -550,13 +549,7 @@ def _serialze_dag_capturing_errors(dag, session):
         # Retry 'DAG.bulk_write_to_db' & 'SerializedDagModel.bulk_sync_to_db' in case
         # of any Operational Errors
         # In case of failures, provide_session handles rollback
-        for attempt in tenacity.Retrying(
-            retry=tenacity.retry_if_exception_type(exception_types=OperationalError),
-            wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
-            stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES),
-            before_sleep=tenacity.before_sleep_log(self.log, logging.DEBUG),
-            reraise=True,
-        ):
+        for attempt in run_with_db_retries(logger=self.log):
             with attempt:
                 serialize_errors = []
                 self.log.debug(
diff --git a/airflow/settings.py b/airflow/settings.py
index d46c0d181d4fbd..85874cdaa727c0 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -25,8 +25,10 @@
 from typing import Optional
 
 import pendulum
+import tenacity
 from sqlalchemy import create_engine, exc
 from sqlalchemy.engine import Engine
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import scoped_session, sessionmaker
 from sqlalchemy.orm.session import Session as SASession
 from sqlalchemy.pool import NullPool
@@ -484,8 +486,7 @@ def initialize():
 # Number of times, the code should be retried in case of DB Operational Errors
 # Retries are done using tenacity. Not all transactions should be retried as it can cause
 # undesired state.
-# Currently used in the following places:
-# `DagFileProcessor.process_file` to retry `dagbag.sync_to_db`
+# Currently used in settings.run_with_db_retries
 MAX_DB_RETRIES = conf.getint('core', 'max_db_retries', fallback=3)
 
 USE_JOB_SCHEDULE = conf.getboolean('scheduler', 'use_job_schedule', fallback=True)
@@ -504,3 +505,18 @@ def initialize():
     executor_constants.KUBERNETES_EXECUTOR,
     executor_constants.CELERY_KUBERNETES_EXECUTOR,
 }
+
+
+def run_with_db_retries(logger: logging.Logger, **kwargs):
+    """Return Tenacity Retrying object with project specific default"""
+    # Default kwargs
+    retry_kwargs = dict(
+        retry=tenacity.retry_if_exception_type(exception_types=OperationalError),
+        wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
+        stop=tenacity.stop_after_attempt(MAX_DB_RETRIES),
+        before_sleep=tenacity.before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    retry_kwargs.update(kwargs)
+
+    return tenacity.Retrying(**retry_kwargs)
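
Note (not part of the patch): a minimal sketch of how a caller outside the scheduler might use the new settings.run_with_db_retries helper. Only run_with_db_retries(logger=...), MAX_DB_RETRIES, and attempt.retry_state.attempt_number come from the diff above; the function name, the TaskInstance query, and the use of airflow.utils.session.create_session are illustrative assumptions.

import logging

from airflow.models import TaskInstance
from airflow.settings import run_with_db_retries
from airflow.utils.session import create_session

log = logging.getLogger(__name__)


def count_task_instances_with_retries() -> int:
    """Hypothetical caller: retry a read-only query on transient DB errors."""
    for attempt in run_with_db_retries(logger=log):
        with attempt:
            # tenacity re-enters this block when an OperationalError escapes it,
            # waiting with random exponential backoff, up to MAX_DB_RETRIES
            # attempts; reraise=True surfaces the final error to the caller.
            log.debug("DB query, attempt %d", attempt.retry_state.attempt_number)
            with create_session() as session:
                # create_session commits on success and rolls back on exception,
                # mirroring the session.rollback() in the except blocks above.
                return session.query(TaskInstance).count()

Because the helper merges **kwargs over its defaults before constructing tenacity.Retrying, any default can be overridden per call site, e.g. run_with_db_retries(logger=log, stop=tenacity.stop_after_attempt(5)).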