Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
elprans committed Oct 18, 2024
1 parent cd766cb commit 0d616f2
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 21 deletions.
16 changes: 11 additions & 5 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,22 +262,28 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
connect=None,
setup=None,
init=None,
loop=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
record_class=asyncpg.Record,
connect_fn=pg_connection.connect,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class, connect_fn=connect_fn,
**connect_kwargs)
record_class=record_class,
**connect_kwargs,
)


class ClusterTestCase(TestCase):
Expand Down
57 changes: 42 additions & 15 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class Pool:

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect_fn', '_connect_args', '_connect_kwargs',
'_init', '_connect', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
Expand All @@ -324,12 +324,12 @@ def __init__(self, *connect_args,
max_size,
max_queries,
max_inactive_connection_lifetime,
setup,
init,
connect = None,
setup = None,
init = None,
loop,
connection_class,
record_class,
connect_fn,
**connect_kwargs):

if len(connect_args) > 1:
Expand Down Expand Up @@ -386,12 +386,14 @@ def __init__(self, *connect_args,
self._closing = False
self._closed = False
self._generation = 0
self._init = init

self._connect = connect if connect is not None else connection.connect
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs
self._connect_fn = connect_fn

self._setup = setup
self._init = init

self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
max_inactive_connection_lifetime
Expand Down Expand Up @@ -505,13 +507,25 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
self._connect_kwargs = connect_kwargs

async def _get_new_connection(self):
con = await self._connect_fn(
con = await self._connect(
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs,
)
if not isinstance(con, self._connection_class):
good = self._connection_class
good_n = f'{good.__module__}.{good.__name__}'
bad = type(con)
if bad.__module__ == "builtins":
bad_n = bad.__name__
else:
bad_n = f'{bad.__module__}.{bad.__name__}'
raise exceptions.InterfaceError(
"expected pool connect callback to return an instance of "
f"'{good_n}', got " f"'{bad_n}'"
)

if self._init is not None:
try:
Expand Down Expand Up @@ -1003,6 +1017,7 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
connect=None,
setup=None,
init=None,
loop=None,
Expand Down Expand Up @@ -1085,6 +1100,13 @@ def create_pool(dsn=None, *,
Number of seconds after which inactive connections in the
pool will be closed. Pass ``0`` to disable this mechanism.
:param coroutine connect:
A coroutine that is called instead of
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
new connection. Must return an instance of type specified by
*connection_class* or :class:`~asyncpg.connection.Connection` if
*connection_class* was not specified.
:param coroutine setup:
A coroutine to prepare a connection right before it is returned
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
Expand All @@ -1099,10 +1121,6 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.
:param coroutine connect_fn:
A coroutine with signature identical to :func:`~asyncpg.connection.connect`. This can be used to add custom
authentication or ssl logic when creating a connection, as is required by GCP's cloud-sql-python-connector.
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand All @@ -1129,12 +1147,21 @@ def create_pool(dsn=None, *,
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
.. versionchanged:: 0.30.0
Added the *connect* parameter.
"""
return Pool(
dsn,
connection_class=connection_class,
record_class=record_class, connect_fn=connection.connect,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
record_class=record_class,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs)
**connect_kwargs,
)
21 changes: 20 additions & 1 deletion tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ async def setup(con):

async def test_pool_07(self):
cons = set()
connect_called = 0

async def connect(*args, **kwargs):
nonlocal connect_called
connect_called += 1
return await pg_connection.connect(*args, **kwargs)

async def setup(con):
if con._con not in cons: # `con` is `PoolConnectionProxy`.
Expand All @@ -152,13 +158,26 @@ async def user(pool):
raise RuntimeError('init was not called')

async with self.create_pool(database='postgres',
min_size=2, max_size=5,
min_size=2,
max_size=5,
connect=connect,
init=init,
setup=setup) as pool:
users = asyncio.gather(*[user(pool) for _ in range(10)])
await users

self.assertEqual(len(cons), 5)
self.assertEqual(connect_called, 5)

async def bad_connect(*args, **kwargs):
return 1

with self.assertRaisesRegex(
asyncpg.InterfaceError,
"expected pool connect callback to return an instance of "
"'asyncpg\\.connection\\.Connection', got 'int'"
):
await self.create_pool(database='postgres', connect=bad_connect)

async def test_pool_08(self):
pool = await self.create_pool(database='postgres',
Expand Down

0 comments on commit 0d616f2

Please sign in to comment.