Merge nowait and skip_locked into with_row_locks (#36889)
Since the two functions are always used in conjunction with with_row_locks, we
can simply handle the two arguments directly in with_row_locks instead of
repeating the same checks at every call site.

The two functions are removed outright since they are not documented and
thus technically not subject to backward compatibility. I doubt anyone is
using them directly, given their highly specific nature.
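
As an illustration of the call-site change (both lines appear verbatim in the diffs below; `TI` is the TaskInstance model used in the scheduler code):

    # Before: each caller expanded dialect-dependent kwargs itself.
    query = with_row_locks(query, of=TI, session=session, **skip_locked(session=session))

    # After: the flag is passed directly, and with_row_locks forwards it to
    # with_for_update() only on database backends that support SKIP LOCKED.
    query = with_row_locks(query, of=TI, session=session, skip_locked=True)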
uranusjr authored Jan 19, 2024
1 parent 78b179f commit cabd768
Showing 8 changed files with 40 additions and 140 deletions.
6 changes: 2 additions & 4 deletions airflow/dag_processing/manager.py
@@ -63,7 +63,7 @@
 )
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks

 if TYPE_CHECKING:
     from multiprocessing.connection import Connection as MultiprocessingConnection
@@ -681,9 +681,7 @@ def _fetch_callbacks_with_retries(self, max_callbacks: int, session: Session):
                 DbCallbackRequest.processor_subdir == self.get_dag_directory(),
             )
             query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks)
-            query = with_row_locks(
-                query, of=DbCallbackRequest, session=session, **skip_locked(session=session)
-            )
+            query = with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True)
             callbacks = session.scalars(query)
             for callback in callbacks:
                 try:
21 changes: 4 additions & 17 deletions airflow/jobs/scheduler_job_runner.py
@@ -68,7 +68,6 @@
 from airflow.utils.sqlalchemy import (
     is_lock_not_available_error,
     prohibit_commit,
-    skip_locked,
     tuple_in_condition,
     with_row_locks,
 )
@@ -399,12 +398,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
             timer.start()

             try:
-                query = with_row_locks(
-                    query,
-                    of=TI,
-                    session=session,
-                    **skip_locked(session=session),
-                )
+                query = with_row_locks(query, of=TI, session=session, skip_locked=True)
                 task_instances_to_examine: list[TI] = session.scalars(query).all()

                 timer.stop(send=True)
@@ -706,12 +700,7 @@ def _process_executor_events(self, session: Session) -> int:
         query = select(TI).where(filter_for_tis).options(selectinload(TI.dag_model))
         # row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
         # multi-schedulers
-        tis_query: Query = with_row_locks(
-            query,
-            of=TI,
-            session=session,
-            **skip_locked(session=session),
-        )
+        tis_query: Query = with_row_locks(query, of=TI, session=session, skip_locked=True)
         tis: Iterator[TI] = session.scalars(tis_query)
         for ti in tis:
             try_number = ti_primary_key_to_try_number_map[ti.key.primary]
@@ -1434,7 +1423,7 @@ def _schedule_dag_run(
             select(DagModel).where(DagModel.dag_id == dag_run.dag_id).options(joinedload(DagModel.parent_dag))
         )
         dag_model = session.scalars(
-            with_row_locks(query, of=DagModel, session=session, **skip_locked(session=session))
+            with_row_locks(query, of=DagModel, session=session, skip_locked=True)
         ).one_or_none()

         if not dag:
@@ -1660,9 +1649,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
             )

             # Lock these rows, so that another scheduler can't try and adopt these too
-            tis_to_adopt_or_reset = with_row_locks(
-                query, of=TI, session=session, **skip_locked(session=session)
-            )
+            tis_to_adopt_or_reset = with_row_locks(query, of=TI, session=session, skip_locked=True)
             tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all()
             to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset)

4 changes: 2 additions & 2 deletions airflow/models/abstractoperator.py
@@ -35,7 +35,7 @@
 from airflow.utils.log.secrets_masker import redact
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
-from airflow.utils.sqlalchemy import skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule
@@ -625,7 +625,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
             TaskInstance.run_id == run_id,
             TaskInstance.map_index >= total_expanded_ti_count,
         )
-        query = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session))
+        query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True)
         to_update = session.scalars(query)
         for ti in to_update:
             ti.state = TaskInstanceState.REMOVED
3 changes: 1 addition & 2 deletions airflow/models/dag.py
@@ -128,7 +128,6 @@
     Interval,
     UtcDateTime,
     lock_rows,
-    skip_locked,
     tuple_in_condition,
     with_row_locks,
 )
@@ -3789,7 +3788,7 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
         )

         return (
-            session.scalars(with_row_locks(query, of=cls, session=session, **skip_locked(session=session))),
+            session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)),
             dataset_triggered_dag_info,
         )

4 changes: 2 additions & 2 deletions airflow/models/dagrun.py
@@ -65,7 +65,7 @@
 from airflow.utils.helpers import chunks, is_container, prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, DagRunType

@@ -365,7 +365,7 @@ def next_dagruns_to_examine(
             query = query.where(DagRun.execution_date <= func.now())

         return session.scalars(
-            with_row_locks(query.limit(max_number), of=cls, session=session, **skip_locked(session=session))
+            with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True)
         )

     @classmethod
4 changes: 2 additions & 2 deletions airflow/models/pool.py
@@ -27,7 +27,7 @@
 from airflow.typing_compat import TypedDict
 from airflow.utils.db import exists_query
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import nowait, with_row_locks
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import TaskInstanceState

 if TYPE_CHECKING:
@@ -172,7 +172,7 @@ def slots_stats(
         query = select(Pool.pool, Pool.slots, Pool.include_deferred)

         if lock_rows:
-            query = with_row_locks(query, session=session, **nowait(session))
+            query = with_row_locks(query, session=session, nowait=True)

         pool_rows = session.execute(query)
         for pool_name, total_slots, include_deferred in pool_rows:
72 changes: 27 additions & 45 deletions airflow/utils/sqlalchemy.py
@@ -334,46 +334,6 @@ def process_result_value(self, value, dialect):
         return data


-def skip_locked(session: Session) -> dict[str, Any]:
-    """
-    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.
-
-    We do this as we document the fact that on DB engines that don't support this construct, we do not
-    support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still
-    work, just slightly slower in some circumstances.
-
-    Specifically don't emit SKIP LOCKED for MySQL < 8, or MariaDB, neither of which support this construct.
-
-    See https://jira.mariadb.org/browse/MDEV-13115
-    """
-    dialect = session.bind.dialect
-
-    if dialect.name != "mysql" or dialect.supports_for_update_of:
-        return {"skip_locked": True}
-    else:
-        return {}
-
-
-def nowait(session: Session) -> dict[str, Any]:
-    """
-    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.
-
-    We do this as we document the fact that on DB engines that don't support this construct, we do not
-    support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still
-    work, just slightly slower in some circumstances.
-
-    Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which support this construct.
-
-    See https://jira.mariadb.org/browse/MDEV-13115
-    """
-    dialect = session.bind.dialect
-
-    if dialect.name != "mysql" or dialect.supports_for_update_of:
-        return {"nowait": True}
-    else:
-        return {}
-
-
 def nulls_first(col, session: Session) -> dict[str, Any]:
     """Specify *NULLS FIRST* to the column ordering.
@@ -390,22 +350,44 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
 USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


-def with_row_locks(query: Query, session: Session, **kwargs) -> Query:
+def with_row_locks(
+    query: Query,
+    session: Session,
+    *,
+    nowait: bool = False,
+    skip_locked: bool = False,
+    **kwargs,
+) -> Query:
     """
-    Apply with_for_update to an SQLAlchemy query, if row level locking is in use.
+    Apply with_for_update to the SQLAlchemy query if row level locking is in use.

     This wrapper is needed so we don't use the syntax on unsupported database
     engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
     row locking, where we do not support nor recommend running HA scheduler. If
     a user ignores this and tries anyway, everything will still work, just
     slightly slower in some circumstances.

     See https://jira.mariadb.org/browse/MDEV-13115

     :param query: An SQLAlchemy Query object
     :param session: ORM Session
+    :param nowait: If set to True, will pass NOWAIT to supported database backends.
+    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
     :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
     :return: updated query
     """
     dialect = session.bind.dialect

     # Don't use row level locks if the MySQL dialect (MariaDB & MySQL < 8) does not support it.
-    if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or dialect.supports_for_update_of):
-        return query.with_for_update(**kwargs)
-    else:
+    if not USE_ROW_LEVEL_LOCKING:
         return query
+    if dialect.name == "mysql" and not dialect.supports_for_update_of:
+        return query
+    if nowait:
+        kwargs["nowait"] = True
+    if skip_locked:
+        kwargs["skip_locked"] = True
+    return query.with_for_update(**kwargs)


 @contextlib.contextmanager
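
The merged behavior can be exercised with mocked dialects, in the same style as the tests deleted below (a minimal sketch; it assumes the default configuration, where [scheduler] use_row_level_locking is True):

    from unittest import mock

    from airflow.utils.sqlalchemy import with_row_locks

    session = mock.Mock()
    query = mock.Mock()

    # PostgreSQL supports FOR UPDATE ... SKIP LOCKED, so the flag is forwarded
    # to with_for_update().
    session.bind.dialect.name = "postgresql"
    session.bind.dialect.supports_for_update_of = True
    with_row_locks(query, session=session, skip_locked=True)
    query.with_for_update.assert_called_once_with(skip_locked=True)

    # MySQL < 8 and MariaDB lack FOR UPDATE OF support, so the query comes back
    # unchanged instead of failing on unsupported syntax.
    session.bind.dialect.name = "mysql"
    session.bind.dialect.supports_for_update_of = False
    assert with_row_locks(query, session=session, nowait=True) is query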
66 changes: 0 additions & 66 deletions tests/utils/test_sqlalchemy.py
@@ -36,9 +36,7 @@
 from airflow.utils.sqlalchemy import (
     ExecutorConfigType,
     ensure_pod_is_valid_after_unpickling,
-    nowait,
     prohibit_commit,
-    skip_locked,
     with_row_locks,
 )
 from airflow.utils.state import State
@@ -117,70 +115,6 @@ def test_process_bind_param_naive(self):
         )
         dag.clear()

-    @pytest.mark.parametrize(
-        "dialect, supports_for_update_of, expected_return_value",
-        [
-            (
-                "postgresql",
-                True,
-                {"skip_locked": True},
-            ),
-            (
-                "mysql",
-                False,
-                {},
-            ),
-            (
-                "mysql",
-                True,
-                {"skip_locked": True},
-            ),
-            (
-                "sqlite",
-                False,
-                {"skip_locked": True},
-            ),
-        ],
-    )
-    def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value):
-        session = mock.Mock()
-        session.bind.dialect.name = dialect
-        session.bind.dialect.supports_for_update_of = supports_for_update_of
-        assert skip_locked(session=session) == expected_return_value
-
-    @pytest.mark.parametrize(
-        "dialect, supports_for_update_of, expected_return_value",
-        [
-            (
-                "postgresql",
-                True,
-                {"nowait": True},
-            ),
-            (
-                "mysql",
-                False,
-                {},
-            ),
-            (
-                "mysql",
-                True,
-                {"nowait": True},
-            ),
-            (
-                "sqlite",
-                False,
-                {
-                    "nowait": True,
-                },
-            ),
-        ],
-    )
-    def test_nowait(self, dialect, supports_for_update_of, expected_return_value):
-        session = mock.Mock()
-        session.bind.dialect.name = dialect
-        session.bind.dialect.supports_for_update_of = supports_for_update_of
-        assert nowait(session=session) == expected_return_value
-
     @pytest.mark.parametrize(
         "dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock",
         [
