Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

When joining a remote room limit the number of events we concurrently check signatures/hashes for #10117

Merged
merged 11 commits on Jun 8, 2021
21 changes: 8 additions & 13 deletions synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,13 @@ async def _check_sigs_and_hash(
pdu: the event to be checked

Returns:
For each input event, a deferred which:
* returns the original event if the checks pass
* returns a redacted version of the event (if the signature
* the original event if the checks pass
* a redacted version of the event (if the signature
matched but the hash did not)
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel
"""
* throws a SynapseError if the signature check failed."""
try:
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
except Exception as e:
except SynapseError as e:
logger.warning(
"Signature check failed for %s: %s",
pdu.event_id,
Expand Down Expand Up @@ -108,17 +105,15 @@ class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferr

async def _check_sigs_on_pdu(
keyring: Keyring, room_version: RoomVersion, pdu: EventBase
):
) -> None:
"""Check that the given events are correctly signed
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

Raise a SynapseError if the event wasn't correctly signed.

Args:
keyring: keyring object to do the checks
room_version: the room version of the PDUs
pdus: the events to be checked

Returns:
A Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""

# we want to check that the event is signed by:
Expand Down Expand Up @@ -165,7 +160,7 @@ async def _check_sigs_on_pdu(
# (ie, the room version uses old-style non-hash event IDs).
if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
pdu.event_id
):
) != get_domain_from_id(pdu.sender):
try:
await keyring.verify_event_for_server(
get_domain_from_id(pdu.event_id),
Expand Down
35 changes: 20 additions & 15 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from synapse.federation.transport.client import SendJoinResponse
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute, yieldable_gather_results
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination

Expand Down Expand Up @@ -360,7 +360,6 @@ async def _check_sigs_and_hash_and_fetch(
pdus: Collection[EventBase],
room_version: RoomVersion,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
Expand All @@ -377,24 +376,30 @@ async def _check_sigs_and_hash_and_fetch(
pdu
room_version
outlier: Whether the events are outliers or not
include_none: Whether to include None in the returned list
for events that have failed their checks

Returns:
A list of PDUs that have valid signatures and hashes.
"""
valid_pdus = await yieldable_gather_results(
self._check_sigs_and_hash_and_fetch_one,
pdus,
origin=origin,
room_version=room_version,
outlier=outlier,
)

if include_none:
return valid_pdus
else:
return [p for p in valid_pdus if p]
# We limit how many PDUs we check at once, as if we try to do hundreds
# of thousands of PDUs at once we see large memory spikes.

valid_pdus = []

async def _execute(pdu: EventBase) -> None:
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=pdu,
origin=origin,
outlier=outlier,
room_version=room_version,
)

if valid_pdu:
valid_pdus.append(valid_pdu)

await concurrently_execute(_execute, pdus, 10000)

return valid_pdus

async def _check_sigs_and_hash_and_fetch_one(
self,
Expand Down
19 changes: 14 additions & 5 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,11 @@ def __repr__(self) -> str:
)


T = TypeVar("T")


def concurrently_execute(
    func: Callable[[T], Any], args: Iterable[T], limit: int
) -> defer.Deferred:
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    Args:
        func: Function to execute; may return a deferred/awaitable or a
            plain value (it is wrapped with `maybe_awaitable`).
        args: Arguments to pass to func; each invocation of func gets a
            single argument.
        limit: Maximum number of concurrent executions.

    Returns:
        Deferred: Resolved when all function invocations have finished.
        Fails with the first error raised by any invocation.
    """
    from itertools import islice

    it = iter(args)

    async def _concurrently_execute_inner(value: T) -> None:
        # Each worker processes its seed value, then keeps pulling further
        # work from the shared iterator until it is exhausted.
        try:
            while True:
                await maybe_awaitable(func(value))
                value = next(it)
        except StopIteration:
            pass

    # We use `islice` to handle the case where the number of args is less
    # than the limit, avoiding needlessly spawning unnecessary background
    # tasks. NOTE: `zip(it, range(limit))` must NOT be used here — when the
    # iterator is longer than `limit`, zip consumes one extra element from
    # `it` before noticing `range` is exhausted, silently dropping that
    # argument (see https://docs.python.org/3/library/functions.html#zip).
    return make_deferred_yieldable(
        defer.gatherResults(
            [
                run_in_background(_concurrently_execute_inner, value)
                for value in islice(it, limit)
            ],
            consumeErrors=True,
        )
    ).addErrback(unwrapFirstError)
Expand Down