Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve exception handling for concurrent execution #12109

Merged
merged 5 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12109.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve exception handling for concurrent execution.
4 changes: 2 additions & 2 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder, json_encoder, log_failure
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
Expand Down
4 changes: 3 additions & 1 deletion synapse/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]:


def unwrapFirstError(failure: Failure) -> Failure:
    """Errback which unwraps a `defer.FirstError` into the failure it wraps.

    `defer.gatherResults` and `DeferredList`s wrap failures in a
    `defer.FirstError`; this pulls the original failure back out again.

    Deprecated: you probably just want to catch `defer.FirstError` and reraise
    the `subFailure`'s value, which will do a better job of preserving
    stacktraces. (Actually, you probably want to use `yieldable_gather_results`
    anyway.)

    Args:
        failure: the failure handed to the errback. It is trapped against
            `defer.FirstError`, so it is expected to wrap one.

    Returns:
        The `subFailure` carried by the wrapped `defer.FirstError`.
    """
    failure.trap(defer.FirstError)
    first_error = failure.value
    return first_error.subFailure  # type: ignore[union-attr]  # Issue in Twisted's annotations

Expand Down
54 changes: 32 additions & 22 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Hashable,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Expand All @@ -51,7 +52,7 @@
make_deferred_yieldable,
run_in_background,
)
from synapse.util import Clock, unwrapFirstError
from synapse.util import Clock

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -193,9 +194,9 @@ def __repr__(self) -> str:
T = TypeVar("T")


def concurrently_execute(
async def concurrently_execute(
func: Callable[[T], Any], args: Iterable[T], limit: int
) -> defer.Deferred:
) -> None:
"""Executes the function with each argument concurrently while limiting
the number of concurrent executions.

Expand All @@ -221,20 +222,14 @@ async def _concurrently_execute_inner(value: T) -> None:
# We use `itertools.islice` to handle the case where the number of args is
# less than the limit, avoiding needlessly spawning unnecessary background
# tasks.
return make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(_concurrently_execute_inner, value)
for value in itertools.islice(it, limit)
],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
await yieldable_gather_results(
_concurrently_execute_inner, (value for value in itertools.islice(it, limit))
)


def yieldable_gather_results(
func: Callable, iter: Iterable, *args: Any, **kwargs: Any
) -> defer.Deferred:
async def yieldable_gather_results(
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
) -> List[T]:
"""Executes the function with each argument concurrently.

Args:
Expand All @@ -245,15 +240,30 @@ def yieldable_gather_results(
**kwargs: Keyword arguments to be passed to each call to func

Returns
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
A list containing the results of the function
"""
return make_deferred_yieldable(
defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter],
consumeErrors=True,
try:
return await make_deferred_yieldable(
defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter],
consumeErrors=True,
)
)
).addErrback(unwrapFirstError)
except defer.FirstError as dfe:
# unwrap the error from defer.gatherResults.

# The raised exception's traceback only includes func() etc if
# the 'await' happens before the exception is thrown - ie if the failure
# happens *asynchronously* - otherwise Twisted throws away the traceback as it
# could be large.
#
# We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
# we could throw Twisted into the fires of Mordor.

# suppress exception chaining, because the FirstError doesn't tell us anything
# very interesting.
assert isinstance(dfe.subFailure.value, BaseException)
raise dfe.subFailure.value from None


T1 = TypeVar("T1")
Expand Down
115 changes: 113 additions & 2 deletions tests/util/test_async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback

from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
from twisted.python.failure import Failure

from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
current_context,
)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
timeout_deferred,
)

from tests.unittest import TestCase

Expand Down Expand Up @@ -171,3 +178,107 @@ def errback(res, deferred_name):
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)


class _TestException(Exception):
pass


class ConcurrentlyExecuteTest(TestCase):
    """Tests for `concurrently_execute`: that it honours the concurrency limit,
    and that exceptions escaping from the callback keep a useful traceback."""

    def test_limits_runners(self):
        """With more items than the limit, only `limit` callbacks run at once"""
        start_count = 0
        gates = []
        seen = []

        async def callback(v):
            # note that we were invoked, and with which item ...
            nonlocal start_count
            start_count += 1
            seen.append(v)

            # ... then block until the test lets us go
            gate = Deferred()
            gates.append(gate)
            await gate

        # kick off five items, with at most three running at a time
        run_d = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))

        # exactly three invocations should be in flight
        self.assertEqual(start_count, 3)
        self.assertEqual(len(gates), 3)

        # releasing one of them ...
        gates.pop().callback(0)

        # ... should cause the next item to be picked up
        self.assertEqual(start_count, 4)
        self.assertEqual(len(gates), 3)

        # there is still work outstanding, so there should be no result yet
        self.assertNoResult(run_d)

        # release everything. This has to be a loop on the live list: the
        # last item only starts (and adds its gate) once a runner is freed.
        while gates:
            gates.pop().callback(0)

        # every item should have been handled, and the whole thing completed
        self.assertEqual(start_count, 5)
        self.assertCountEqual(seen, [1, 2, 3, 4, 5])
        self.successResultOf(run_d)

    def test_preserves_stacktraces(self):
        """An exception raised in the callback keeps its stacktrace"""
        release = Deferred()

        # NB: the names `caller` and `callback` are checked against the
        # traceback frame names below, so they must not be renamed.

        async def callback(v):
            # alas, this doesn't work at all without an await here
            await release
            raise _TestException("bah")

        async def caller():
            try:
                await concurrently_execute(callback, [1], 2)
            except _TestException as exc:
                frames = traceback.extract_tb(exc.__traceback__)
                # outermost frames should be "caller" then
                # "concurrently_execute"; innermost should be "callback".
                self.assertEqual(frames[0].name, "caller")
                self.assertEqual(frames[1].name, "concurrently_execute")
                self.assertEqual(frames[-1].name, "callback")
            else:
                self.fail("No exception thrown")

        result_d = ensureDeferred(caller())
        release.callback(0)
        self.successResultOf(result_d)

    def test_preserves_stacktraces_on_preformed_failure(self):
        """A Failure returned by the callback keeps its stacktrace"""
        release = Deferred()
        preformed = Failure(_TestException("bah"))

        # NB: the names `caller` and `callback` are checked against the
        # traceback frame names below, so they must not be renamed.

        async def callback(v):
            # alas, this doesn't work at all without an await here
            await release
            await defer.fail(preformed)

        async def caller():
            try:
                await concurrently_execute(callback, [1], 2)
            except _TestException as exc:
                frames = traceback.extract_tb(exc.__traceback__)
                # outermost frames should be "caller" then
                # "concurrently_execute". "callback" is second from last: it
                # is followed by some magic from inside ensureDeferred that
                # happens when .fail is called.
                self.assertEqual(frames[0].name, "caller")
                self.assertEqual(frames[1].name, "concurrently_execute")
                self.assertEqual(frames[-2].name, "callback")
            else:
                self.fail("No exception thrown")

        result_d = ensureDeferred(caller())
        release.callback(0)
        self.successResultOf(result_d)