Replace yield from with await... #283

Merged: 9 commits, May 1, 2018
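
This PR migrates aiomysql from `@asyncio.coroutine`/`yield from` generator coroutines to native `async def`/`await` coroutines (Python 3.5+). A minimal before/after sketch of the pattern applied throughout the diff below; the function names and the `cursor` argument are illustrative placeholders, not code from the PR:

```python
import asyncio

# Old style (pre-PR): generator-based coroutine; the decorator was removed in Python 3.11.
@asyncio.coroutine
def fetch_first_row_old(cursor):
    yield from cursor.execute("SELECT 1")
    row = yield from cursor.fetchone()
    return row

# New style (this PR): native coroutine.
async def fetch_first_row(cursor):
    await cursor.execute("SELECT 1")
    row = await cursor.fetchone()
    return row
```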
263 changes: 113 additions & 150 deletions aiomysql/connection.py

Large diffs are not rendered by default.

143 changes: 60 additions & 83 deletions aiomysql/cursors.py
@@ -1,4 +1,3 @@
import asyncio
import re
import warnings

@@ -8,7 +7,7 @@
NotSupportedError, ProgrammingError)

from .log import logger
from .utils import PY_35, create_future
from .utils import create_future


# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18
@@ -149,14 +148,13 @@ def closed(self):
"""
return True if not self._connection else False

@asyncio.coroutine
def close(self):
async def close(self):
"""Closing a cursor just exhausts all remaining data."""
conn = self._connection
if conn is None:
return
try:
while (yield from self.nextset()):
while (await self.nextset()):
pass
finally:
self._connection = None
@@ -179,17 +177,16 @@ def setinputsizes(self, *args):
def setoutputsizes(self, *args):
"""Does nothing, required by DB API."""

@asyncio.coroutine
def nextset(self):
async def nextset(self):
"""Get the next query set"""
conn = self._get_db()
current_result = self._result
if current_result is None or current_result is not conn._result:
return
if not current_result.has_next:
return
yield from conn.next_result()
yield from self._do_get_result()
await conn.next_result()
await self._do_get_result()
return True

def _escape_args(self, args, conn):
@@ -215,8 +212,7 @@ def mogrify(self, query, args=None):
query = query % self._escape_args(args, conn)
return query

@asyncio.coroutine
def execute(self, query, args=None):
async def execute(self, query, args=None):
"""Executes the given operation

Executes the given operation substituting any markers with
@@ -231,21 +227,20 @@
"""
conn = self._get_db()

while (yield from self.nextset()):
while (await self.nextset()):
pass

if args is not None:
query = query % self._escape_args(args, conn)

yield from self._query(query)
await self._query(query)
self._executed = query
if self._echo:
logger.info(query)
logger.info("%r", args)
return self._rowcount

@asyncio.coroutine
def executemany(self, query, args):
async def executemany(self, query, args):
"""Execute the given operation multiple times

The executemany() method will execute the operation iterating
@@ -259,7 +254,7 @@ def executemany(self, query, args):
('John', '555-003')
]
stmt = "INSERT INTO employees (name, phone) VALUES ('%s','%s')"
yield from cursor.executemany(stmt, data)
await cursor.executemany(stmt, data)

INSERT or REPLACE statements are optimized by batching the data,
that is using the MySQL multiple rows syntax.
@@ -280,20 +275,19 @@ def executemany(self, query, args):
q_values = m.group(2).rstrip()
q_postfix = m.group(3) or ''
assert q_values[0] == '(' and q_values[-1] == ')'
return (yield from self._do_execute_many(
return (await self._do_execute_many(
q_prefix, q_values, q_postfix, args, self.max_stmt_length,
self._get_db().encoding))
else:
rows = 0
for arg in args:
yield from self.execute(query, arg)
await self.execute(query, arg)
rows += self._rowcount
self._rowcount = rows
return self._rowcount

@asyncio.coroutine
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length,
encoding):
async def _do_execute_many(self, prefix, values, postfix, args,
max_stmt_length, encoding):
conn = self._get_db()
escape = self._escape_args
if isinstance(prefix, str):
@@ -312,19 +306,18 @@ def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length,
if isinstance(v, str):
v = v.encode(encoding, 'surrogateescape')
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
r = yield from self.execute(sql + postfix)
r = await self.execute(sql + postfix)
rows += r
sql = bytearray(prefix)
else:
sql += b','
sql += v
r = yield from self.execute(sql + postfix)
r = await self.execute(sql + postfix)
rows += r
self._rowcount = rows
return rows

@asyncio.coroutine
def callproc(self, procname, args=()):
async def callproc(self, procname, args=()):
"""Execute stored procedure procname with args

Compatibility warning: PEP-249 specifies that any modified
@@ -357,12 +350,12 @@ def callproc(self, procname, args=()):

for index, arg in enumerate(args):
q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
yield from self._query(q)
yield from self.nextset()
await self._query(q)
await self.nextset()

_args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args)))
q = "CALL %s(%s)" % (procname, _args)
yield from self._query(q)
await self._query(q)
self._executed = q
return args
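
For reference, a call site written against the converted `callproc()` might look like the sketch below; the `get_user_by_id` procedure and its argument are assumed purely for illustration:

```python
async def lookup_user(conn, user_id):
    # `get_user_by_id` is an assumed stored procedure, shown only to
    # illustrate the now-awaitable callproc()/fetchone()/close() chain.
    cur = await conn.cursor()
    try:
        await cur.callproc('get_user_by_id', (user_id,))
        row = await cur.fetchone()
    finally:
        await cur.close()
    return row
```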

@@ -454,15 +447,13 @@ def scroll(self, value, mode='relative'):
fut.set_result(None)
return fut

@asyncio.coroutine
def _query(self, q):
async def _query(self, q):
conn = self._get_db()
self._last_executed = q
yield from conn.query(q)
yield from self._do_get_result()
await conn.query(q)
await self._do_get_result()

@asyncio.coroutine
def _do_get_result(self):
async def _do_get_result(self):
conn = self._get_db()
self._rownumber = 0
self._result = result = conn._result
Expand All @@ -472,13 +463,12 @@ def _do_get_result(self):
self._rows = result.rows

if result.warning_count > 0:
yield from self._show_warnings(conn)
await self._show_warnings(conn)

@asyncio.coroutine
def _show_warnings(self, conn):
async def _show_warnings(self, conn):
if self._result and self._result.has_next:
return
ws = yield from conn.show_warnings()
ws = await conn.show_warnings()
if ws is None:
return
for w in ws:
Expand All @@ -496,36 +486,30 @@ def _show_warnings(self, conn):
ProgrammingError = ProgrammingError
NotSupportedError = NotSupportedError

if PY_35: # pragma: no branch
@asyncio.coroutine
def __aiter__(self):
return self
async def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
ret = yield from self.fetchone()
if ret is not None:
return ret
else:
raise StopAsyncIteration # noqa
async def __anext__(self):
ret = await self.fetchone()
if ret is not None:
return ret
else:
raise StopAsyncIteration # noqa

@asyncio.coroutine
def __aenter__(self):
return self
async def __aenter__(self):
return self

@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
yield from self.close()
return
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
return


class _DictCursorMixin:
# You can override this to use OrderedDict or other dict-like types.
dict_type = dict

@asyncio.coroutine
def _do_get_result(self):
yield from super()._do_get_result()
async def _do_get_result(self):
await super()._do_get_result()
fields = []
if self._description:
for f in self._result.fields:
@@ -563,61 +547,55 @@ class SSCursor(Cursor):
possible to scroll backwards, as only the current row is held in memory.
"""

@asyncio.coroutine
def close(self):
async def close(self):
conn = self._connection
if conn is None:
return

if self._result is not None and self._result is conn._result:
yield from self._result._finish_unbuffered_query()
await self._result._finish_unbuffered_query()

try:
while (yield from self.nextset()):
while (await self.nextset()):
pass
finally:
self._connection = None

@asyncio.coroutine
def _query(self, q):
async def _query(self, q):
conn = self._get_db()
self._last_executed = q
yield from conn.query(q, unbuffered=True)
yield from self._do_get_result()
await conn.query(q, unbuffered=True)
await self._do_get_result()
return self._rowcount

@asyncio.coroutine
def _read_next(self):
async def _read_next(self):
"""Read next row """
row = yield from self._result._read_rowdata_packet_unbuffered()
row = await self._result._read_rowdata_packet_unbuffered()
row = self._conv_row(row)
return row

@asyncio.coroutine
def fetchone(self):
async def fetchone(self):
""" Fetch next row """
self._check_executed()
row = yield from self._read_next()
row = await self._read_next()
if row is None:
return
self._rownumber += 1
return row

@asyncio.coroutine
def fetchall(self):
async def fetchall(self):
"""Fetch all, as per MySQLdb. Pretty useless for large queries, as
it is buffered.
"""
rows = []
while True:
row = yield from self.fetchone()
row = await self.fetchone()
if row is None:
break
rows.append(row)
return rows

@asyncio.coroutine
def fetchmany(self, size=None):
async def fetchmany(self, size=None):
"""Returns the next set of rows of a query result, returning a
list of tuples. When no more rows are available, it returns an
empty list.
@@ -634,15 +612,14 @@ def fetchmany(self, size=None):

rows = []
for i in range(size):
row = yield from self._read_next()
row = await self._read_next()
if row is None:
break
rows.append(row)
self._rownumber += 1
return rows

@asyncio.coroutine
def scroll(self, value, mode='relative'):
async def scroll(self, value, mode='relative'):
"""Scroll the cursor in the result set to a new position
according to mode . Same as :meth:`Cursor.scroll`, but move cursor
on server side one by one row. If you want to move 20 rows forward
@@ -661,7 +638,7 @@ def scroll(self, value, mode='relative'):
"by this cursor")

for _ in range(value):
yield from self._read_next()
await self._read_next()
self._rownumber += value
elif mode == 'absolute':
if value < self._rownumber:
Expand All @@ -670,7 +647,7 @@ def scroll(self, value, mode='relative'):

end = value - self._rownumber
for _ in range(end):
yield from self._read_next()
await self._read_next()
self._rownumber = value
else:
raise ProgrammingError("unknown scroll mode %s" % mode)
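
Since `__aenter__`, `__aexit__`, `__aiter__`, and `__anext__` are now plain native coroutines (the `PY_35` guard is gone), a cursor can be driven directly with `async with` and `async for`. A sketch of such usage, with the connection parameters and the `users` table assumed for illustration:

```python
import asyncio

import aiomysql


async def main():
    # Connection parameters and the `users` table are assumed for illustration.
    conn = await aiomysql.connect(host='127.0.0.1', user='root',
                                  password='', db='test')
    cur = await conn.cursor()
    async with cur:                    # converted __aenter__ / __aexit__
        await cur.execute("SELECT id, name FROM users")
        async for row in cur:          # converted __aiter__ / __anext__
            print(row)
    conn.close()


asyncio.get_event_loop().run_until_complete(main())
```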