diff --git a/aiomysql/sa/engine.py b/aiomysql/sa/engine.py index b6282f3f..243e5001 100644 --- a/aiomysql/sa/engine.py +++ b/aiomysql/sa/engine.py @@ -12,11 +12,29 @@ try: from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql + from sqlalchemy.dialects.mysql.mysqldb import MySQLCompiler_mysqldb except ImportError: # pragma: no cover raise ImportError('aiomysql.sa requires sqlalchemy') +class MySQLCompiler_pymysql(MySQLCompiler_mysqldb): + def construct_params(self, params=None, _group_number=None, _check=True): + pd = super().construct_params(params, _group_number, _check) + + for column in self.prefetch: + pd[column.key] = self._exec_default(column.default) + + return pd + + def _exec_default(self, default): + if default.is_callable: + return default.arg(self.dialect) + else: + return default.arg + + _dialect = MySQLDialect_pymysql(paramstyle='pyformat') +_dialect.statement_compiler = MySQLCompiler_pymysql _dialect.default_paramstyle = 'pyformat' diff --git a/tests/sa/test_sa_default.py b/tests/sa/test_sa_default.py new file mode 100644 index 00000000..14974508 --- /dev/null +++ b/tests/sa/test_sa_default.py @@ -0,0 +1,117 @@ +import asyncio +import datetime +import os +import unittest +from unittest import mock + +import sqlalchemy as sa + +import aiomysql.sa +from aiomysql import connect + +meta = sa.MetaData() +table = sa.Table('sa_tbl', meta, + sa.Column('id', sa.Integer, nullable=False, primary_key=True), + sa.Column('string_length', sa.Integer, + default=sa.func.length('qwerty')), + sa.Column('number', sa.Integer, default=100, nullable=False), + sa.Column('description', sa.String(255), nullable=False, + default='default test'), + sa.Column('created_at', sa.DateTime, + default=datetime.datetime.now), + sa.Column('enabled', sa.Boolean, default=True)) + + +class TestSAConnection(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + self.host = os.environ.get('MYSQL_HOST', 'localhost') + self.port = int(os.environ.get('MYSQL_PORT', 3306)) + self.user = os.environ.get('MYSQL_USER', 'root') + self.db = os.environ.get('MYSQL_DB', 'test_pymysql') + self.password = os.environ.get('MYSQL_PASSWORD', '') + + def tearDown(self): + self.loop.close() + + async def connect(self, **kwargs): + conn = await connect(db=self.db, + user=self.user, + password=self.password, + host=self.host, + loop=self.loop, + port=self.port, + **kwargs) + await conn.autocommit(True) + cur = await conn.cursor() + await cur.execute("DROP TABLE IF EXISTS sa_tbl") + await cur.execute("CREATE TABLE sa_tbl " + "(id integer, string_length integer, number integer," + " description VARCHAR(255), created_at DATETIME(6), " + "enabled TINYINT)") + + await cur._connection.commit() + # await cur.close() + engine = mock.Mock() + engine.dialect = aiomysql.sa.engine._dialect + return aiomysql.sa.SAConnection(conn, engine) + + def test_default_fields(self): + async def go(): + conn = await self.connect() + await conn.execute(table.insert().values()) + + res = await conn.execute(table.select()) + row = await res.fetchone() + self.assertEqual(row.string_length, 6) + self.assertEqual(row.number, 100) + self.assertEqual(row.description, 'default test') + self.assertEqual(row.enabled, True) + self.assertEqual(type(row.created_at), datetime.datetime) + + self.loop.run_until_complete(go()) + + def test_default_fields_isnull(self): + async def go(): + conn = await self.connect() + created_at = None + enabled = False + await conn.execute(table.insert().values( + enabled=enabled, + created_at=created_at, + )) + + res = await conn.execute(table.select()) + row = await res.fetchone() + self.assertEqual(row.number, 100) + self.assertEqual(row.string_length, 6) + self.assertEqual(row.description, 'default test') + self.assertEqual(row.enabled, enabled) + self.assertEqual(row.created_at, created_at) + + self.loop.run_until_complete(go()) + + def test_default_fields_edit(self): + async def go(): + conn = await self.connect() + created_at = datetime.datetime.now() + description = 'new descr' + enabled = False + number = 111 + await conn.execute(table.insert().values( + description=description, + enabled=enabled, + created_at=created_at, + number=number, + )) + + res = await conn.execute(table.select()) + row = await res.fetchone() + self.assertEqual(row.number, number) + self.assertEqual(row.string_length, 6) + self.assertEqual(row.description, description) + self.assertEqual(row.enabled, enabled) + self.assertEqual(row.created_at, created_at) + + self.loop.run_until_complete(go())