forked from aio-libs/aiomysql
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added support for sqlalchemy default parameters aio-libs#455
- Loading branch information
Ганжин Михаил
committed
Dec 11, 2019
1 parent
b16e5bd
commit 160a4c5
Showing
2 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import asyncio | ||
import datetime | ||
import os | ||
import unittest | ||
from unittest import mock | ||
|
||
import sqlalchemy as sa | ||
|
||
import aiomysql.sa | ||
from aiomysql import connect | ||
|
||
meta = sa.MetaData() | ||
table = sa.Table('sa_tbl', meta, | ||
sa.Column('id', sa.Integer, nullable=False, primary_key=True), | ||
sa.Column('string_length', sa.Integer, | ||
default=sa.func.length('qwerty')), | ||
sa.Column('number', sa.Integer, default=100, nullable=False), | ||
sa.Column('description', sa.String(255), nullable=False, | ||
default='default test'), | ||
sa.Column('created_at', sa.DateTime, | ||
default=datetime.datetime.now), | ||
sa.Column('enabled', sa.Boolean, default=True)) | ||
|
||
|
||
class TestSAConnection(unittest.TestCase): | ||
def setUp(self): | ||
self.loop = asyncio.new_event_loop() | ||
asyncio.set_event_loop(None) | ||
self.host = os.environ.get('MYSQL_HOST', 'localhost') | ||
self.port = int(os.environ.get('MYSQL_PORT', 3306)) | ||
self.user = os.environ.get('MYSQL_USER', 'root') | ||
self.db = os.environ.get('MYSQL_DB', 'test_pymysql') | ||
self.password = os.environ.get('MYSQL_PASSWORD', '') | ||
|
||
def tearDown(self): | ||
self.loop.close() | ||
|
||
async def connect(self, **kwargs): | ||
conn = await connect(db=self.db, | ||
user=self.user, | ||
password=self.password, | ||
host=self.host, | ||
loop=self.loop, | ||
port=self.port, | ||
**kwargs) | ||
await conn.autocommit(True) | ||
cur = await conn.cursor() | ||
await cur.execute("DROP TABLE IF EXISTS sa_tbl") | ||
await cur.execute("CREATE TABLE sa_tbl " | ||
"(id integer, string_length integer, number integer," | ||
" description VARCHAR(255), created_at DATETIME(6), " | ||
"enabled TINYINT)") | ||
|
||
await cur._connection.commit() | ||
# await cur.close() | ||
engine = mock.Mock() | ||
engine.dialect = aiomysql.sa.engine._dialect | ||
return aiomysql.sa.SAConnection(conn, engine) | ||
|
||
def test_default_fields(self): | ||
async def go(): | ||
conn = await self.connect() | ||
await conn.execute(table.insert().values()) | ||
|
||
res = await conn.execute(table.select()) | ||
row = await res.fetchone() | ||
self.assertEqual(row.string_length, 6) | ||
self.assertEqual(row.number, 100) | ||
self.assertEqual(row.description, 'default test') | ||
self.assertEqual(row.enabled, True) | ||
self.assertEqual(type(row.created_at), datetime.datetime) | ||
|
||
self.loop.run_until_complete(go()) | ||
|
||
def test_default_fields_isnull(self): | ||
async def go(): | ||
conn = await self.connect() | ||
created_at = None | ||
enabled = False | ||
await conn.execute(table.insert().values( | ||
enabled=enabled, | ||
created_at=created_at, | ||
)) | ||
|
||
res = await conn.execute(table.select()) | ||
row = await res.fetchone() | ||
self.assertEqual(row.number, 100) | ||
self.assertEqual(row.string_length, 6) | ||
self.assertEqual(row.description, 'default test') | ||
self.assertEqual(row.enabled, enabled) | ||
self.assertEqual(row.created_at, created_at) | ||
|
||
self.loop.run_until_complete(go()) | ||
|
||
def test_default_fields_edit(self): | ||
async def go(): | ||
conn = await self.connect() | ||
created_at = datetime.datetime.now() | ||
description = 'new descr' | ||
enabled = False | ||
number = 111 | ||
await conn.execute(table.insert().values( | ||
description=description, | ||
enabled=enabled, | ||
created_at=created_at, | ||
number=number, | ||
)) | ||
|
||
res = await conn.execute(table.select()) | ||
row = await res.fetchone() | ||
self.assertEqual(row.number, number) | ||
self.assertEqual(row.string_length, 6) | ||
self.assertEqual(row.description, description) | ||
self.assertEqual(row.enabled, enabled) | ||
self.assertEqual(row.created_at, created_at) | ||
|
||
self.loop.run_until_complete(go()) |