diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index 83101cea..ab208acf 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -262,22 +262,28 @@ def create_pool(dsn=None, *, max_size=10, max_queries=50000, max_inactive_connection_lifetime=60.0, + connect=None, setup=None, init=None, loop=None, pool_class=pg_pool.Pool, connection_class=pg_connection.Connection, record_class=asyncpg.Record, - connect_fn=pg_connection.connect, **connect_kwargs): return pool_class( dsn, - min_size=min_size, max_size=max_size, - max_queries=max_queries, loop=loop, setup=setup, init=init, + min_size=min_size, + max_size=max_size, + max_queries=max_queries, + loop=loop, + connect=connect, + setup=setup, + init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, connection_class=connection_class, - record_class=record_class, connect_fn=connect_fn, - **connect_kwargs) + record_class=record_class, + **connect_kwargs, + ) class ClusterTestCase(TestCase): diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 1e9a0457..33279aaa 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -313,7 +313,7 @@ class Pool: __slots__ = ( '_queue', '_loop', '_minsize', '_maxsize', - '_init', '_connect_fn', '_connect_args', '_connect_kwargs', + '_init', '_connect', '_connect_args', '_connect_kwargs', '_holders', '_initialized', '_initializing', '_closing', '_closed', '_connection_class', '_record_class', '_generation', '_setup', '_max_queries', '_max_inactive_connection_lifetime' @@ -324,12 +324,12 @@ def __init__(self, *connect_args, max_size, max_queries, max_inactive_connection_lifetime, - setup, - init, + connect = None, + setup = None, + init = None, loop, connection_class, record_class, - connect_fn, **connect_kwargs): if len(connect_args) > 1: @@ -386,12 +386,14 @@ def __init__(self, *connect_args, self._closing = False self._closed = False self._generation = 0 - self._init = init + + self._connect = connect if connect is not None else connection.connect self._connect_args = connect_args self._connect_kwargs = connect_kwargs - self._connect_fn = connect_fn self._setup = setup + self._init = init + self._max_queries = max_queries self._max_inactive_connection_lifetime = \ max_inactive_connection_lifetime @@ -505,13 +507,25 @@ def set_connect_args(self, dsn=None, **connect_kwargs): self._connect_kwargs = connect_kwargs async def _get_new_connection(self): - con = await self._connect_fn( + con = await self._connect( *self._connect_args, loop=self._loop, connection_class=self._connection_class, record_class=self._record_class, **self._connect_kwargs, ) + if not isinstance(con, self._connection_class): + good = self._connection_class + good_n = f'{good.__module__}.{good.__name__}' + bad = type(con) + if bad.__module__ == "builtins": + bad_n = bad.__name__ + else: + bad_n = f'{bad.__module__}.{bad.__name__}' + raise exceptions.InterfaceError( + "expected pool connect callback to return an instance of " + f"'{good_n}', got " f"'{bad_n}'" + ) if self._init is not None: try: @@ -1003,6 +1017,7 @@ def create_pool(dsn=None, *, max_size=10, max_queries=50000, max_inactive_connection_lifetime=300.0, + connect=None, setup=None, init=None, loop=None, @@ -1085,6 +1100,13 @@ def create_pool(dsn=None, *, Number of seconds after which inactive connections in the pool will be closed. Pass ``0`` to disable this mechanism. + :param coroutine connect: + A coroutine that is called instead of + :func:`~asyncpg.connection.connect` whenever the pool needs to make a + new connection. Must return an instance of type specified by + *connection_class* or :class:`~asyncpg.connection.Connection` if + *connection_class* was not specified. + :param coroutine setup: A coroutine to prepare a connection right before it is returned from :meth:`Pool.acquire() `. An example use @@ -1099,10 +1121,6 @@ def create_pool(dsn=None, *, or :meth:`Connection.set_type_codec() <\ asyncpg.connection.Connection.set_type_codec>`. - :param coroutine connect_fn: - A coroutine with signature identical to :func:`~asyncpg.connection.connect`. This can be used to add custom - authentication or ssl logic when creating a connection, as is required by GCP's cloud-sql-python-connector. - :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -1129,12 +1147,21 @@ def create_pool(dsn=None, *, .. versionchanged:: 0.22.0 Added the *record_class* parameter. + + .. versionchanged:: 0.30.0 + Added the *connect* parameter. """ return Pool( dsn, connection_class=connection_class, - record_class=record_class, connect_fn=connection.connect, - min_size=min_size, max_size=max_size, - max_queries=max_queries, loop=loop, setup=setup, init=init, + record_class=record_class, + min_size=min_size, + max_size=max_size, + max_queries=max_queries, + loop=loop, + connect=connect, + setup=setup, + init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, - **connect_kwargs) + **connect_kwargs, + ) diff --git a/tests/test_pool.py b/tests/test_pool.py index 2407b817..5bd70bd9 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -136,6 +136,12 @@ async def setup(con): async def test_pool_07(self): cons = set() + connect_called = 0 + + async def connect(*args, **kwargs): + nonlocal connect_called + connect_called += 1 + return await pg_connection.connect(*args, **kwargs) async def setup(con): if con._con not in cons: # `con` is `PoolConnectionProxy`. @@ -152,13 +158,26 @@ async def user(pool): raise RuntimeError('init was not called') async with self.create_pool(database='postgres', - min_size=2, max_size=5, + min_size=2, + max_size=5, + connect=connect, init=init, setup=setup) as pool: users = asyncio.gather(*[user(pool) for _ in range(10)]) await users self.assertEqual(len(cons), 5) + self.assertEqual(connect_called, 5) + + async def bad_connect(*args, **kwargs): + return 1 + + with self.assertRaisesRegex( + asyncpg.InterfaceError, + "expected pool connect callback to return an instance of " + "'asyncpg\\.connection\\.Connection', got 'int'" + ): + await self.create_pool(database='postgres', connect=bad_connect) async def test_pool_08(self): pool = await self.create_pool(database='postgres',