Skip to content

Commit

Permalink
Add TimeoutError to be a retryable error in databricks provider (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
rawwar authored Oct 18, 2024
1 parent 424cf50 commit 0de5587
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import copy
import platform
import time
from asyncio.exceptions import TimeoutError
from functools import cached_property
from typing import TYPE_CHECKING, Any
from urllib.parse import urlsplit
Expand Down Expand Up @@ -679,7 +680,7 @@ def _retryable_error(exception: BaseException) -> bool:
if exception.status >= 500 or exception.status == 429:
return True

if isinstance(exception, ClientConnectorError):
if isinstance(exception, (ClientConnectorError, TimeoutError)):
return True

return False
Expand Down
11 changes: 11 additions & 0 deletions providers/tests/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json
import ssl
import time
from asyncio.exceptions import TimeoutError
from unittest import mock
from unittest.mock import AsyncMock

Expand Down Expand Up @@ -1551,6 +1552,16 @@ async def test_do_api_call_retries_with_client_connector_error(self, mock_get):
await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
assert mock_errors.call_count == DEFAULT_RETRY_NUMBER

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_do_api_call_retries_with_client_timeout_error(self, mock_get):
mock_get.side_effect = TimeoutError()
with mock.patch.object(self.hook.log, "error") as mock_errors:
async with self.hook:
with pytest.raises(AirflowException):
await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
assert mock_errors.call_count == DEFAULT_RETRY_NUMBER

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_do_api_call_retries_with_retryable_error(self, mock_get):
Expand Down

0 comments on commit 0de5587

Please sign in to comment.