Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rollback for all retry exceptions (#40882) #40883

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow/utils/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from inspect import signature
from typing import Callable, TypeVar, overload

from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.exc import DBAPIError

from airflow.configuration import conf

Expand All @@ -36,7 +36,7 @@ def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: logging.Logge

# Default kwargs
retry_kwargs = dict(
retry=tenacity.retry_if_exception_type(exception_types=(OperationalError, DBAPIError)),
retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError)),
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
stop=tenacity.stop_after_attempt(max_retries),
reraise=True,
Expand All @@ -58,7 +58,7 @@ def retry_db_transaction(_func: F) -> F: ...

def retry_db_transaction(_func: Callable | None = None, *, retries: int = MAX_DB_RETRIES, **retry_kwargs):
"""
Retry functions in case of ``OperationalError`` from DB.
Retry functions in case of ``DBAPIError`` from DB.

It should not be used with ``@provide_session``.
"""
Expand Down Expand Up @@ -96,7 +96,7 @@ def wrapped_function(*args, **kwargs):
)
try:
return func(*args, **kwargs)
except OperationalError:
except DBAPIError:
session.rollback()
raise

Expand Down
15 changes: 10 additions & 5 deletions tests/utils/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from unittest import mock

import pytest
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import InternalError, OperationalError

from airflow.utils.retries import retry_db_transaction

if TYPE_CHECKING:
from sqlalchemy.exc import DBAPIError


class TestRetries:
def test_retry_db_transaction_with_passing_retries(self):
Expand All @@ -45,23 +49,24 @@ def test_function(session):
assert mock_obj.call_count == 2

@pytest.mark.db_test
def test_retry_db_transaction_with_default_retries(self, caplog):
@pytest.mark.parametrize("excection_type", [OperationalError, InternalError])
def test_retry_db_transaction_with_default_retries(self, caplog, excection_type: type[DBAPIError]):
"""Test that by default 3 retries will be carried out"""
mock_obj = mock.MagicMock()
mock_session = mock.MagicMock()
mock_rollback = mock.MagicMock()
mock_session.rollback = mock_rollback
op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)
db_error = excection_type(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)

@retry_db_transaction
def test_function(session):
session.execute("select 1")
mock_obj(2)
raise op_error
raise db_error

caplog.set_level(logging.DEBUG, logger=self.__module__)
caplog.clear()
with pytest.raises(OperationalError):
with pytest.raises(excection_type):
test_function(session=mock_session)

for try_no in range(1, 4):
Expand Down