Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close SQLite connection if session is deleted and thread is still running #189

Merged
merged 3 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions aiohttp_client_cache/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ async def _init_db(self):
)
return self._connection

def __del__(self):
"""If the aiosqlite connection is still open when this object is deleted, force its thread
to close by emptying its internal queue and setting its ``_running`` flag to ``False``.
This is basically a last resort to avoid hanging the application if this backend is used
without the CachedSession contextmanager.

Note: Since this uses internal attributes, it has the potential to break in future versions
of aiosqlite.
"""
if self._connection is not None:
self._connection._tx.queue.clear()
self._connection._running = False
self._connection = None

@asynccontextmanager
async def bulk_commit(self):
"""Contextmanager to more efficiently write a large number of records at once
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
def test(session):
"""Run tests for a specific python version"""
test_paths = session.posargs or [UNIT_TESTS]
session.install('.', 'pytest', 'pytest-xdist', 'requests-mock', 'timeout-decorator')
session.install('.', 'pytest', 'pytest-aiohttp', 'pytest-asyncio', 'pytest-xdist')

cmd = f'pytest -rs {XDIST_ARGS}'
session.run(*cmd.split(' '), *test_paths)
Expand Down
11 changes: 7 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ docs = ["furo", "linkify-it-py", "markdown-it-py", "myst-parser", "python

[tool.poetry.dev-dependencies]
# For unit + integration tests
async-timeout = ">=4.0"
brotli = ">=1.0"
pytest = ">=6.2"
pytest-aiohttp = "^0.3"
Expand Down
23 changes: 21 additions & 2 deletions test/integration/base_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from uuid import uuid4

import pytest
from async_timeout import timeout
from itsdangerous.exc import BadSignature
from itsdangerous.serializer import Serializer

Expand Down Expand Up @@ -35,13 +36,17 @@ class BaseBackendTest:

@asynccontextmanager
async def init_session(self, clear=True, **kwargs) -> AsyncIterator[CachedSession]:
session = await self._init_session(clear=clear, **kwargs)
async with session:
yield session

async def _init_session(self, clear=True, **kwargs) -> CachedSession:
kwargs.setdefault('allowed_methods', ALL_METHODS)
cache = self.backend_class(CACHE_NAME, **self.init_kwargs, **kwargs)
if clear:
await cache.clear()

async with CachedSession(cache=cache, **self.init_kwargs, **kwargs) as session:
yield session
return CachedSession(cache=cache, **self.init_kwargs, **kwargs)

@pytest.mark.parametrize('method', HTTPBIN_METHODS)
@pytest.mark.parametrize('field', ['params', 'data', 'json'])
Expand Down Expand Up @@ -100,6 +105,20 @@ async def get_url(mysession, url):
responses = await asyncio.gather(*tasks)
assert all([r.from_cache is True for r in responses])

async def test_without_contextmanager(self):
"""Test that the cache backend can be safely used without the CachedSession contextmanager.
An "unclosed ClientSession" warning is expected here, however.
"""
# Timeout to avoid hanging if the test fails
async with timeout(5.0):
session = await self._init_session()
await session.get(httpbin('get'))
del session

session = await self._init_session(clear=False)
r = await session.get(httpbin('get'))
assert r.from_cache is True

async def test_request__expire_after(self):
async with self.init_session() as session:
await session.get(httpbin('get'), expire_after=1)
Expand Down
8 changes: 8 additions & 0 deletions test/integration/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ async def test_content_reset(self):
content_2 = await cached_response_2.read()
assert content_1 == content_2 == original_content

async def test_without_contextmanager(self):
"""Test that the cache backend can be safely used without the CachedSession contextmanager.
An "unclosed ClientSession" warning is expected here, however.
"""
session = await self._init_session()
await session.get(httpbin('get'))
del session

# Serialization tests don't apply to in-memory cache
async def test_serializer__pickle(self):
pass
Expand Down
24 changes: 15 additions & 9 deletions test/integration/test_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
from contextlib import asynccontextmanager
from tempfile import gettempdir
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -53,17 +54,22 @@ async def test_concurrent_bulk_commit(self, mock_sqlite):
mock_connection = AsyncMock()
mock_sqlite.connect = AsyncMock(return_value=mock_connection)

async with self.init_cache() as cache:
@asynccontextmanager
async def bulk_commit_ctx():
async with self.init_cache() as cache:

async def bulk_commit_items(n_items):
async with cache.bulk_commit():
for i in range(n_items):
await cache.write(f'key_{n_items}_{i}', f'value_{i}')

async def bulk_commit_items(n_items):
async with cache.bulk_commit():
for i in range(n_items):
await cache.write(f'key_{n_items}_{i}', f'value_{i}')
yield bulk_commit_items

assert mock_connection.commit.call_count == 1
tasks = [asyncio.create_task(bulk_commit_items(n)) for n in [10, 100, 1000, 10000]]
await asyncio.gather(*tasks)
assert mock_connection.commit.call_count == 5
async with bulk_commit_ctx() as bulk_commit_items:
assert mock_connection.commit.call_count == 1
tasks = [asyncio.create_task(bulk_commit_items(n)) for n in [10, 100, 1000, 10000]]
await asyncio.gather(*tasks)
assert mock_connection.commit.call_count == 5

async def test_fast_save(self):
async with self.init_cache(index=1, fast_save=True) as cache_1, self.init_cache(
Expand Down
21 changes: 8 additions & 13 deletions test/unit/test_base_backend.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import pickle
from sys import version_info
from unittest.mock import MagicMock, patch

import pytest

from aiohttp_client_cache import CachedResponse
from aiohttp_client_cache.backends import CacheBackend, DictCache, get_placeholder_backend
from test.conftest import skip_37

TEST_URL = 'https://test.com'

pytestmark = pytest.mark.asyncio
skip_py37 = pytest.mark.skipif(
version_info < (3, 8), reason='Test requires AsyncMock from python 3.8+'
)


def get_mock_response(**kwargs):
response_kwargs = {
Expand Down Expand Up @@ -71,7 +66,7 @@ async def test_get_response__cache_miss(mock_delete):
mock_delete.assert_not_called()


@skip_py37
@skip_37
@patch.object(CacheBackend, 'delete')
@patch.object(CacheBackend, 'is_cacheable', return_value=False)
async def test_get_response__cache_expired(mock_is_cacheable, mock_delete):
Expand All @@ -84,7 +79,7 @@ async def test_get_response__cache_expired(mock_is_cacheable, mock_delete):
mock_delete.assert_called_with('request-key')


@skip_py37
@skip_37
@pytest.mark.parametrize('error_type', [AttributeError, KeyError, TypeError, pickle.PickleError])
@patch.object(CacheBackend, 'delete')
@patch.object(DictCache, 'read')
Expand All @@ -99,7 +94,7 @@ async def test_get_response__cache_invalid(mock_read, mock_delete, error_type):
mock_delete.assert_not_called()


@skip_py37
@skip_37
@patch.object(DictCache, 'read', return_value=object())
async def test_get_response__quiet_serde_error(mock_read):
"""Test for a quiet deserialization error in which no errors are raised but attributes are
Expand All @@ -113,7 +108,7 @@ async def test_get_response__quiet_serde_error(mock_read):
assert response is None


@skip_py37
@skip_37
async def test_save_response():
cache = CacheBackend()
mock_response = get_mock_response()
Expand All @@ -126,7 +121,7 @@ async def test_save_response():
assert await cache.redirects.read(redirect_key) == 'key'


@skip_py37
@skip_37
async def test_save_response__manual_save():
"""Manually save a response with no cache key provided"""
cache = CacheBackend()
Expand Down Expand Up @@ -193,7 +188,7 @@ async def test_has_url():
assert not await cache.has_url('https://test.com/some_other_path')


@skip_py37
@skip_37
@patch('aiohttp_client_cache.backends.base.create_key')
async def test_create_key(mock_create_key):
"""Actual logic is in cache_keys module; just test to make sure it gets called correctly"""
Expand Down Expand Up @@ -244,7 +239,7 @@ async def test_is_cacheable(method, status, disabled, expired, filter_return, ex
assert await cache.is_cacheable(mock_response) is expected_result


@skip_py37
@skip_37
@pytest.mark.parametrize(
'method, status, disabled, expired, body, expected_result',
[
Expand Down
Loading