From 313d221d5c012fb38d9ef847fddee2ac61541c1d Mon Sep 17 00:00:00 2001 From: elBroom Date: Fri, 18 Jan 2019 20:28:58 +0300 Subject: [PATCH] Allow to set expand Cursor --- aiomysql/sa/engine.py | 22 +++++++++++++++------- tests/test_async_with.py | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/aiomysql/sa/engine.py b/aiomysql/sa/engine.py index fc61121d..b6282f3f 100644 --- a/aiomysql/sa/engine.py +++ b/aiomysql/sa/engine.py @@ -6,7 +6,8 @@ from .connection import SAConnection from .exc import InvalidRequestError, ArgumentError from ..utils import _PoolContextManager, _PoolAcquireContextManager -from ..cursors import Cursor +from ..cursors import ( + Cursor, DeserializationCursor, DictCursor, SSCursor, SSDictCursor) try: @@ -26,16 +27,23 @@ def create_engine(minsize=1, maxsize=10, loop=None, Returns Engine instance with embedded connection pool. - The pool has *minsize* opened connections to PostgreSQL server. + The pool has *minsize* opened connections to MySQL server. """ + deprecated_cursor_classes = [ + DeserializationCursor, DictCursor, SSCursor, SSDictCursor, + ] + + cursorclass = kwargs.get('cursorclass', Cursor) + if not issubclass(cursorclass, Cursor) or any( + issubclass(cursorclass, cursor_class) + for cursor_class in deprecated_cursor_classes + ): + raise ArgumentError('SQLAlchemy engine does not support ' + 'this cursor class') + coro = _create_engine(minsize=minsize, maxsize=maxsize, loop=loop, dialect=dialect, pool_recycle=pool_recycle, compiled_cache=compiled_cache, **kwargs) - compatible_cursor_classes = [Cursor] - # Without provided kwarg, default is default cursor from Connection class - if kwargs.get('cursorclass', Cursor) not in compatible_cursor_classes: - raise ArgumentError('SQLAlchemy engine does not support ' - 'this cursor class') return _EngineContextManager(coro) diff --git a/tests/test_async_with.py b/tests/test_async_with.py index 0be32d32..6aa265e5 100644 --- a/tests/test_async_with.py +++ b/tests/test_async_with.py @@ -3,7 +3,7 @@ import aiomysql import pytest -from aiomysql import sa, create_pool, DictCursor +from aiomysql import sa, create_pool, DictCursor, Cursor from sqlalchemy import MetaData, Table, Column, Integer, String @@ -276,3 +276,16 @@ async def test_incompatible_cursor_fails(loop, mysql_params): msg = 'SQLAlchemy engine does not support this cursor class' assert str(ctx.value) == msg + + +@pytest.mark.run_loop +async def test_compatible_cursor_correct(loop, mysql_params): + class SubCursor(Cursor): + pass + + mysql_params['cursorclass'] = SubCursor + async with sa.create_engine(loop=loop, **mysql_params) as engine: + async with engine.acquire() as conn: + # check not raise sa.ArgumentError exception + pass + assert conn.closed