diff --git a/tox.ini b/tox.ini index 5b0b8b8..707b270 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,8 @@ envlist = pypy3-twisted15 [testenv] +passenv = DBTYPE +sitepackages = True deps = coverage twisted>=15.0, <16.0 diff --git a/twistar/dbconfig/base.py b/twistar/dbconfig/base.py index d2e45cb..1f66c9c 100644 --- a/twistar/dbconfig/base.py +++ b/twistar/dbconfig/base.py @@ -8,6 +8,7 @@ from twistar.registry import Registry from twistar.exceptions import ImaginaryTableError, CannotRefreshError from twistar.utils import joinWheres +from twistar.transaction import TransactionGuard class InteractionBase(object): @@ -25,7 +26,7 @@ class InteractionBase(object): def __init__(self): - self.txn = None + self.txnGuard = TransactionGuard() def logEncode(self, s, encoding='utf-8'): @@ -156,7 +157,6 @@ def _doselect(self, txn, q, args, tablename, one=False, cacheable=True): results.append(vals) return results - def insertArgsToString(self, vals): """ Convert C{{'name': value}} to an insert "values" string like C{"(%s,%s,%s)"}. @@ -327,8 +327,8 @@ def getSchema(self, tablename, txn=None): def runInteraction(self, interaction, *args, **kwargs): - if self.txn is not None: - return defer.succeed(interaction(self.txn, *args, **kwargs)) + if self.txnGuard.txn is not None: + return defer.succeed(interaction(self.txnGuard.txn, *args, **kwargs)) return Registry.DBPOOL.runInteraction(interaction, *args, **kwargs) @@ -345,8 +345,7 @@ def _doinsert(txn): if len(cols) == 0: raise ImaginaryTableError("Table %s does not exist." % tablename) vals = obj.toHash(cols, includeBlank=self.__class__.includeBlankInInsert, exclude=['id']) - self.insert(tablename, vals, txn) - obj.id = self.getLastInsertID(txn) + obj.id = self.insert(tablename, vals, txn) return obj return self.runInteraction(_doinsert) diff --git a/twistar/dbobject.py b/twistar/dbobject.py index cdfc7f2..1dd96fc 100644 --- a/twistar/dbobject.py +++ b/twistar/dbobject.py @@ -6,7 +6,8 @@ from twistar.registry import Registry from twistar.relationships import Relationship from twistar.exceptions import InvalidRelationshipError, DBObjectSaveError, ReferenceNotSavedError -from twistar.utils import createInstances, deferredDict, dictToWhere, transaction +from twistar.utils import createInstances, deferredDict, dictToWhere +from twistar.transaction import transaction from twistar.validation import Validator, Errors from BermiInflector.Inflector import Inflector diff --git a/twistar/tests/test_transactions.py b/twistar/tests/test_transactions.py index 8629f99..a0c329e 100644 --- a/twistar/tests/test_transactions.py +++ b/twistar/tests/test_transactions.py @@ -1,67 +1,382 @@ +import sys +from threading import Event + from twisted.trial import unittest -from twisted.internet.defer import inlineCallbacks +from twisted.internet import reactor +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, maybeDeferred, DeferredList +from twisted.python import threadable -from twistar.utils import transaction +from twistar.transaction import transaction, nested_transaction from twistar.exceptions import TransactionError -from utils import initDB, tearDownDB, Registry, Transaction +from twistar.tests.utils import initDB, tearDownDB, Registry, Transaction, DBTYPE + +class TransactionTests(unittest.TestCase): -class TransactionTest(unittest.TestCase): @inlineCallbacks def setUp(self): yield initDB(self) self.config = Registry.getConfig() - @inlineCallbacks def tearDown(self): - yield tearDownDB(self) + d_tearDown = tearDownDB(self) + delayed = reactor.callLater(2, d_tearDown.cancel) + try: + yield d_tearDown + delayed.cancel() + except: + print "db cleanup timed out" @inlineCallbacks - def test_findOrCreate(self): + def _assertRaises(self, deferred, *excTypes): + # required for downward compatibility + + excType = None + try: + yield deferred + except: + excType, exc, tb = sys.exc_info() + + msgFormat = "Deferred expected to fail with " + ", ".join(str(expType) for expType in excTypes) + "; instead got {}" + if not excType: + self.fail(msgFormat.format("Nothing")) + else: + self.failIf(not issubclass(excType, *excTypes), msgFormat.format(excType)) + + @transaction + def test_set_cfg_txn(txn, self): + """Verify that the transaction is actually being set correctly""" + self.assertIs(txn, Registry.getConfig().txnGuard.txn) + + with transaction() as txn2: + self.assertIs(txn2, Registry.getConfig().txnGuard.txn) + + self.assertIs(txn, Registry.getConfig().txnGuard.txn) + + @inlineCallbacks + def test_commit(self): + barrier = Event() + @transaction @inlineCallbacks - def interaction(txn): - yield Transaction.findOrCreate(name="a name") - yield Transaction.findOrCreate(name="a name") + def trans(txn): + self.assertFalse(threadable.isInIOThread(), "Transactions must not run in main thread") + + yield Transaction(name="TEST1").save() + yield Transaction(name="TEST2").save() + + barrier.wait() # wait here to delay commit + returnValue("return value") + + d = trans() - yield interaction() count = yield Transaction.count() - self.assertEqual(count, 1) + self.assertEqual(count, 0) + + barrier.set() + res = yield d + self.assertEqual(res, "return value") + count = yield Transaction.count() + self.assertEqual(count, 2) @inlineCallbacks - def test_doubleInsert(self): + def test_rollback(self): + barrier = Event() @transaction - def interaction(txn): - def finish(trans): - return Transaction(name="unique name").save() - return Transaction(name="unique name").save().addCallback(finish) + @inlineCallbacks + def trans(txn): + yield Transaction(name="TEST1").save() + yield Transaction(name="TEST2").save() - try: - yield interaction() - except TransactionError: - pass + barrier.wait() # wait here to delay commit + raise ZeroDivisionError() + + d = trans() + + barrier.set() + yield self._assertRaises(d, ZeroDivisionError) - # there should be no transaction records stored at all count = yield Transaction.count() self.assertEqual(count, 0) - @inlineCallbacks - def test_success(self): + def test_fake_nesting_commit(self): + barrier = Event() + threadIds = [] + + @transaction + @inlineCallbacks + def trans1(txn): + threadIds.append(threadable.getThreadID()) + yield Transaction(name="TEST1").save() @transaction - def interaction(txn): - def finish(trans): - return Transaction(name="unique name two").save() - return Transaction(name="unique name").save().addCallback(finish) + @inlineCallbacks + def trans2(txn): + threadIds.append(threadable.getThreadID()) + yield trans1() + yield Transaction(name="TEST2").save() + barrier.wait() # wait here to delay commit - result = yield interaction() - self.assertEqual(result.id, 2) + d = trans2() + + count = yield Transaction.count() + self.assertEqual(count, 0) + + barrier.set() + yield d + + self.assertEqual(threadIds[0], threadIds[1], "Nested transactions don't run in same thread") count = yield Transaction.count() self.assertEqual(count, 2) + + @inlineCallbacks + def test_fake_nesting_rollback(self): + barrier = Event() + + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + txn.rollback() # should propagate to the root transaction + + @transaction + @inlineCallbacks + def trans2(txn): + yield Transaction(name="TEST2").save() + yield trans1() + + barrier.wait() # wait here to delay commit + + d = trans2() + + count = yield Transaction.count() + self.assertEqual(count, 0) + + barrier.set() + + yield d + + count = yield Transaction.count() + self.assertEqual(count, 0) + + @inlineCallbacks + def test_fake_nesting_ctxmgr(self): + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + with transaction() as txn2: + yield Transaction(name="TEST2").save() + txn2.rollback() + + yield trans1() + + count = yield Transaction.count() + self.assertEqual(count, 0) + + @inlineCallbacks + def test_parallel_transactions(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("Parallel connections are not supported by sqlite") + + threadIds = [] + + # trans1 is supposed to pass, trans2 is supposed to fail due to unique constraint + # regarding synchronization: trans1 has to start INSERT before trans2, + # because otherwise it would wait for trans2 to finish due to postgres synchronization strategy + + on_trans1_insert = Event() + barrier1, barrier2 = Event(), Event() + + @transaction + @inlineCallbacks + def trans1(txn): + threadIds.append(threadable.getThreadID()) + yield Transaction(name="TEST1").save() + on_trans1_insert.set() + barrier1.wait() # wait here to delay commit) + + @transaction + @inlineCallbacks + def trans2(txn): + threadIds.append(threadable.getThreadID()) + on_trans1_insert.wait() + yield Transaction(name="TEST1").save() + barrier2.wait() # wait here to delay commit + + d1 = trans1() + d2 = trans2() + + # commit tran1, should pass: + barrier1.set() + yield d1 + + count = yield Transaction.count() + self.assertEqual(count, 1) + + # commit trans2: + barrier2.set() + + # should fail due to unique constraint violation + yield self._assertRaises(d2, Exception) + + self.assertNotEqual(threadIds[0], threadIds[1], "Parallel transactions don't run in different threads") + + count = yield Transaction.count() + self.assertEqual(count, 1) + + @inlineCallbacks + def test_parallel_massive(self): + # Make sure that everything works alright even when starting a massive amount of parallel transactions + if DBTYPE == "sqlite": + raise unittest.SkipTest("Parallel connections are not supported by sqlite") + + N = 100 + + @transaction + @inlineCallbacks + def trans(txn, i): + yield Transaction(name=str(i)).save() + if i % 2 == 1: + txn.rollback() + else: + txn.commit() + + deferreds = [trans(i) for i in range(N)] + + results = yield DeferredList(deferreds) + self.assertTrue(all(success for success, result in results)) + + objects = yield Transaction.all() + actual = sorted(int(obj.name) for obj in objects) + actual = [str(i) for i in actual] + expected = [str(i) for i in range(0, N, 2)] + + self.assertEquals(actual, expected) + + @inlineCallbacks + def test_savepoints_commit(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("SAVEPOINT acts weird with sqlite, needs further inspection.") + + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + with nested_transaction(): + yield Transaction(name="TEST2").save() + yield Transaction(name="TEST3").save() + + yield trans1() + objects = yield Transaction.all() + self.assertEqual([obj.name for obj in objects], ["TEST1", "TEST2", "TEST3"]) + + @inlineCallbacks + def test_savepoints_rollback(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("SAVEPOINT acts weird with sqlite, needs further inspection.") + + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + with nested_transaction() as txn2: + yield Transaction(name="TEST2").save() + txn2.rollback() + yield Transaction(name="TEST3").save() + + yield trans1() + objects = yield Transaction.all() + self.assertEqual([obj.name for obj in objects], ["TEST1", "TEST3"]) + + @inlineCallbacks + def test_savepoints_mixed(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("SAVEPOINT acts weird with sqlite, needs further inspection.") + + @nested_transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST3").save() + with transaction() as txn2: + yield Transaction(name="TEST4").save() + txn2.rollback() + + @transaction + @inlineCallbacks + def trans2(txn): + yield Transaction(name="TEST1").save() + with nested_transaction(): + yield Transaction(name="TEST2").save() + yield trans1() + yield Transaction(name="TEST5").save() + + yield trans2() + objects = yield Transaction.all() + self.assertEqual([obj.name for obj in objects], ["TEST1", "TEST2", "TEST5"]) + + @inlineCallbacks + def test_sanity_checks(self): + # Already rollbacked/commited: + @transaction + def trans1(txn): + txn.rollback() + txn.commit() + + yield self._assertRaises(trans1(), TransactionError) + + # With nesting: + @transaction + def trans2(txn): + with transaction() as txn2: + txn2.rollback() + txn.commit() + + yield self._assertRaises(trans2(), TransactionError) + + # Error if started in main thread: + yield self._assertRaises(maybeDeferred(transaction), TransactionError) + + # But shouldn't fail if called with thread_check=False + transaction(thread_check=False).rollback() + + # Error if rollbacked/commited in another thread: + main_thread_d = Deferred() + on_cb_added = Event() + on_callbacked = Event() + + @transaction + def trans3(txn): + def from_mainthread(do_commit): + if do_commit: + txn.commit() + else: + txn.rollback() + + main_thread_d.addCallback(from_mainthread) + on_cb_added.set() + on_callbacked.wait() # don't return (which would cause commit) until main thread executed callbacks + return main_thread_d # deferred will fail if from_mainthread() raised an Exception + + d = trans3() + on_cb_added.wait() # we need to wait for the callback to be registered otherwise it would be executed in db thread + main_thread_d.callback(True) # will commit the transaction in main thread + on_callbacked.set() + yield self._assertRaises(d, TransactionError) + + main_thread_d = Deferred() + on_cb_added.clear() + on_callbacked.clear() + + d = trans3() + on_cb_added.wait() + main_thread_d.callback(False) # will rollback the transaction in main thread + on_callbacked.set() + yield self._assertRaises(d, TransactionError) diff --git a/twistar/transaction.py b/twistar/transaction.py new file mode 100644 index 0000000..67397c0 --- /dev/null +++ b/twistar/transaction.py @@ -0,0 +1,239 @@ +import threading +import functools + +from twisted.enterprise import adbapi +from twisted.internet.defer import maybeDeferred, Deferred +from twisted.python import threadable + +from twistar.registry import Registry +from twistar.exceptions import TransactionError + + +class TransactionGuard(threading.local): + + def __init__(self): + self._txn = None + + @property + def txn(self): + return self._txn + + @txn.setter + def txn(self, txn): + self._txn = txn + + +class _Transaction(object): + """Mostly borrowed from sqlalchemy and adapted to adbapi""" + + def __init__(self, parent, thread_check=True): + # Transactions must be started in db thread unless explicitely permitted + if thread_check and threading.current_thread() not in Registry.DBPOOL.threadpool.threads: + raise TransactionError("Transaction must only be started in a db pool thread") + + if parent is None: + self._root = self + else: + self._root = parent._root + + self._actual_parent = parent + self.is_active = True + self._threadId = threadable.getThreadID() + self._savepoint_seq = 0 + + if not self._parent.is_active: + raise TransactionError("Parent transaction is inactive") + + Registry.getConfig().txnGuard.txn = self + + @property + def _parent(self): + return self._actual_parent or self + + def _assertCorrectThread(self): + if threadable.getThreadID() != self._threadId: + raise TransactionError("Tried to rollback a transaction from a different thread.\n" + "Make sure that you properly use blockingCallFromThread() and\n" + "that you don't add callbacks to Deferreds which get resolved from another thread.") + + def rollback(self): + self._assertCorrectThread() + + if not self._parent.is_active: + return + + Registry.getConfig().txnGuard.txn = self._actual_parent + self._do_rollback() + self.is_active = False + + def _do_rollback(self): + self._parent.rollback() + + def commit(self): + self._assertCorrectThread() + + if not self._parent.is_active: + raise TransactionError("This transaction is inactive") + + Registry.getConfig().txnGuard.txn = self._actual_parent + self._do_commit() + self.is_active = False + + def _do_commit(self): + pass + + def __enter__(self): + return self + + def __exit__(self, excType, exc, traceback): + if excType is not None and issubclass(excType, Exception): + self.rollback() + elif self.is_active: + try: + self.commit() + except: + self.rollback() + raise + + def __getattr__(self, key): + return getattr(self._root, key) + + +class _RootTransaction(adbapi.Transaction, _Transaction): + + def __init__(self, pool, connection, thread_check=True): + adbapi.Transaction.__init__(self, pool, connection) + _Transaction.__init__(self, None, thread_check=thread_check) + + def close(self): + # don't set to None but errorout on subsequent access + self._cursor.close() + + def _do_rollback(self): + if self.is_active: + self._connection.rollback() + self.close() + + def _do_commit(self): + if self.is_active: + self._connection.commit() + self.close() + + def __getattr__(self, key): + return getattr(self._cursor, key) + + +class _SavepointTransaction(_Transaction): + + def __init__(self, parent, thread_check=True): + super(_SavepointTransaction, self).__init__(parent, thread_check=thread_check) + + self._root._savepoint_seq += 1 + self._name = "twistar_savepoint_{}".format(self._root._savepoint_seq) + + self.execute("SAVEPOINT {}".format(self._name)) + + def _do_rollback(self): + if self.is_active: + self.execute("ROLLBACK TO SAVEPOINT {}".format(self._name)) + + def _do_commit(self): + if self.is_active: + self.execute("RELEASE SAVEPOINT {}".format(self._name)) + + +def _transaction_dec(func, create_transaction): + + def _runTransaction(*args, **kwargs): + txn = create_transaction() + + def on_succcess(result): + if txn.is_active: + try: + txn.commit() + except: + txn.rollback() + return result + + def on_error(fail): + if txn.is_active: + txn.rollback() + + return fail + + d = maybeDeferred(func, txn, *args, **kwargs) + d.addCallbacks(on_succcess, on_error) + d.addErrback(on_error) + return d + + @functools.wraps(func) + def wrapper(*args, **kwargs): + d = None # declare here so that on_result can access it + + def on_result(success, txn_deferred): + from twisted.internet import reactor + txn_deferred.addCallbacks(lambda res: reactor.callFromThread(d.callback, res), + lambda fail: reactor.callFromThread(d.errback, fail)) + + if threadable.isInIOThread(): + d = Deferred() + thpool = Registry.DBPOOL.threadpool + thpool.callInThreadWithCallback(on_result, _runTransaction, *args, **kwargs) + return d + else: + # we are already in a db thread, so just execute the transaction + return _runTransaction(*args, **kwargs) + + return wrapper + + +def transaction(func=None, nested=False, thread_check=True): + """Starts a new transaction. + + A Transaction object returned by this function can be used as a context manager, + which will atomatically be commited or rolledback if an exception is raised. + + Transactions must only be used in db threads. This behaviour can be overriden by setting the + 'thread_check' to False, allowing transactions to be started in arbitrary threads which is + useful to e.g simplify testcases. + + If this function is used as decorator, the decorated function will be executed in a db thread and + gets the Transaction passed as first argument. Decorated functions are allowed to return Deferreds. + E.g: + @transaction + def someFunc(txn, param1): + # Runs in a db thread + + d = someFunc(1) # will be calledback (in mainthread) when someFunc returns + + You have to make sure, that you use blockingCallFromThread() or use synchronization if you need to + interact with code which runs in the mainthread. Also care has to be taken when waiting for Deferreds. + You must assure that the callbacks will be invoked from the db thread. + + Per default transactions can be nested: Commiting such a "nested" transaction will simply do nothing, + but a rollback on it will rollback the outermost transaction. This allow creation of functions which will + either create a new transaction or will participate in an already ongoing tranaction which is handy for library code. + + SAVEPOINT transactions can be used by either setting the 'nested' flag to true or by calling the 'nested_transaction' function. + """ + if nested and Registry.DBPOOL.dbapi.__name__ == "sqlite3": + # needs some modification on our side, see: + # http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + raise NotImplementedError("sqlite currently not supported") + + if func is None: + conn_pool = Registry.DBPOOL + cfg = Registry.getConfig() + + if cfg.txnGuard.txn is None: + conn = conn_pool.connectionFactory(conn_pool) + return _RootTransaction(conn_pool, conn, thread_check=thread_check) + elif nested: + return _SavepointTransaction(cfg.txnGuard.txn, thread_check=thread_check) + else: + return _Transaction(cfg.txnGuard.txn, thread_check=thread_check) + else: + return _transaction_dec(func, functools.partial(transaction, nested=nested, thread_check=thread_check)) + + +nested_transaction = functools.partial(transaction, nested=True) diff --git a/twistar/utils.py b/twistar/utils.py index c2efbdb..dceb4b8 100644 --- a/twistar/utils.py +++ b/twistar/utils.py @@ -8,30 +8,6 @@ from twistar.exceptions import TransactionError -def transaction(interaction): - """ - A decorator to wrap any code in a transaction. If any exceptions are raised, all modifications - are rolled back. The function that is decorated should accept at least one argument, which is - the transaction (in case you want to operate directly on it). - """ - def _transaction(txn, args, kwargs): - config = Registry.getConfig() - config.txn = txn - # get the result of the functions *synchronously*, since this is in a transaction - try: - result = threads.blockingCallFromThread(reactor, interaction, txn, *args, **kwargs) - config.txn = None - return result - except Exception, e: - config.txn = None - raise TransactionError(str(e)) - - def wrapper(*args, **kwargs): - return Registry.DBPOOL.runInteraction(_transaction, args, kwargs) - - return wrapper - - def createInstances(props, klass): """ Create an instance of C{list} of instances of a given class