diff --git a/krake/krake/api/app.py b/krake/krake/api/app.py index 12a007ebe..cf3f0c13f 100644 --- a/krake/krake/api/app.py +++ b/krake/krake/api/app.py @@ -21,7 +21,10 @@ """ import logging import ssl +from functools import partial + from aiohttp import web, ClientSession +from krake.api.database import Session from krake.data.core import RoleBinding from . import __version__ as version @@ -90,11 +93,7 @@ def create_app(config): middlewares=[ middlewares.error_log(), authentication, - middlewares.database( - host=config.etcd.host, - port=config.etcd.port, - retry=config.etcd.retry_transactions, - ), + middlewares.retry_transaction(retry=config.etcd.retry_transactions), ], ) app["config"] = config @@ -103,6 +102,9 @@ def create_app(config): # Cleanup contexts app.cleanup_ctx.append(http_session) + app.cleanup_ctx.append( + partial(db_session, host=config.etcd.host, port=config.etcd.port) + ) # Routes app.add_routes(routes) @@ -113,6 +115,23 @@ def create_app(config): return app +async def db_session(app, host, port): + """Async generator creating an database :class:`krake.api.database.Session` that can + be used by other components (middleware, route handlers) or the requests. The + database session is available under the ``db`` key of the application. + + This function should be used as cleanup context (see + :attr:`aiohttp.web.Application.cleanup_ctx`). + + Args: + app (aiohttp.web.Application): Web application + + """ + async with Session(host=host, port=port) as session: + app["db"] = session + yield + + async def http_session(app): """Async generator creating an :class:`aiohttp.ClientSession` HTTP session that can be used by other components (middleware, route handlers). The HTTP diff --git a/krake/krake/api/helpers.py b/krake/krake/api/helpers.py index ba84620c3..2f21f1f54 100644 --- a/krake/krake/api/helpers.py +++ b/krake/krake/api/helpers.py @@ -41,8 +41,8 @@ def json_error(exc, content): def session(request): """Load the database session for a given aiohttp request - Internally, it just returns the value that was assigned by - func:`krake.middlewares.database`. + Internally, it just returns the value that was given as cleanup context by + func:`krake.api.app.db_session`. Args: request (aiohttp.web.Request): HTTP request @@ -50,7 +50,7 @@ def session(request): Returns: krake.database.Session: Database session for the given request """ - return request["db"] + return request.app["db"] class Heartbeat(object): diff --git a/krake/krake/api/middlewares.py b/krake/krake/api/middlewares.py index 0dc4ea0f5..3cc06cb17 100644 --- a/krake/krake/api/middlewares.py +++ b/krake/krake/api/middlewares.py @@ -4,50 +4,42 @@ from aiohttp import web from krake.api.helpers import HttpReason, HttpReasonCode -from .database import Session, TransactionError +from .database import TransactionError -def database(host, port, retry=1): - """Middleware factory for per-request etcd database sessions and - transaction error handling. +def retry_transaction(retry=1): + """Middleware factory for transaction error handling. - If an :class:`.database.TransactionError` occurs, the request handler is - retried for the specified number of times. If the transaction error - persists, a *409 Conflict* HTTP exception is raised. + If an :class:`.database.TransactionError` occurs, the request handler is retried for + the specified number of times. If the transaction error persists, a *409 Conflict* + HTTP exception is raised. Args: - host (str): Host of the etcd server - port (int): TCP port of the etcd server - retry (int, optional): Number of retries if a transaction error - occurs. + retry (int, optional): Number of retries if a transaction error occurs. Returns: - aiohttp middleware injecting an etcd database session into each HTTP - request and handling transaction errors. + coroutine: aiohttp middleware handling transaction errors. """ # TODO: Maybe we can share the TCP connection pool across all HTTP # handlers (like for SQLAlchemy engines) @web.middleware - async def database_middleware(request, handler): - async with Session(host=host, port=port) as session: - request["db"] = session - - for _ in range(retry + 1): - try: - return await handler(request) - except TransactionError as err: - request.app.logger.warn("Transaction failed (%s)", err) - - reason = HttpReason( - reason="Concurrent writes to database", - code=HttpReasonCode.TRANSACTION_ERROR, - ) - raise web.HTTPConflict( - text=json.dumps(reason.serialize()), content_type="application/json" - ) - - return database_middleware + async def retry_transaction_middleware(request, handler): + for _ in range(retry + 1): + try: + return await handler(request) + except TransactionError as err: + request.app.logger.warn("Transaction failed (%s)", err) + + reason = HttpReason( + reason="Concurrent writes to database", + code=HttpReasonCode.TRANSACTION_ERROR, + ) + raise web.HTTPConflict( + text=json.dumps(reason.serialize()), content_type="application/json" + ) + + return retry_transaction_middleware def error_log():