diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index f40b071a7410..d58096f447c2 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1031,6 +1031,12 @@ async def _get_state_ids_after_missing_prev_event(
             InvalidResponseError: if the remote homeserver's response contains
                 fields of the wrong type.
         """
+
+        # It would be better if we could query the difference between our known
+        # state and the state at the given `event_id`, so the sending server
+        # doesn't have to send as much and we don't have to process so many
+        # events. For example, in a room like #matrixhq we get 200k events
+        # (77k state_events, 122k auth_events) and just `have_seen_events` takes 20s.
        (
            state_event_ids,
            auth_event_ids,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8a7cdb024d60..4b3f958b1943 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,7 +54,13 @@
     current_context,
     make_deferred_yieldable,
 )
-from synapse.logging.opentracing import start_active_span, tag_args, trace
+from synapse.logging.opentracing import (
+    SynapseTags,
+    set_tag,
+    start_active_span,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -1449,7 +1455,7 @@ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
     @trace
     @tag_args
     async def have_seen_events(
-        self, room_id: str, event_ids: Iterable[str]
+        self, room_id: str, event_ids: Collection[str]
     ) -> Set[str]:
         """Given a list of event ids, check if we have already processed them.
 
@@ -1462,44 +1468,43 @@
             event_ids: events we are looking for
 
         Returns:
-            The set of events we have already seen.
+            The remaining set of events we haven't seen.
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
 
         # @cachedList chomps lots of memory if you call it with a big list, so
         # we break it down. However, each batch requires its own index scan, so we make
         # the batches as big as possible.
 
-        results: Set[str] = set()
-        for chunk in batch_iter(event_ids, 500):
-            r = await self._have_seen_events_dict(
-                [(room_id, event_id) for event_id in chunk]
-            )
-            results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
+        remaining_event_ids: Set[str] = set()
+        for chunk in batch_iter(event_ids, 1000):
+            remaining_event_ids_from_chunk = await self._have_seen_events_dict(chunk)
+            remaining_event_ids.update(remaining_event_ids_from_chunk)
 
-        return results
+        return remaining_event_ids
 
-    @cachedList(cached_method_name="have_seen_event", list_name="keys")
-    async def _have_seen_events_dict(
-        self, keys: Collection[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], bool]:
+    # @cachedList(cached_method_name="have_seen_event", list_name="event_ids")
+    async def _have_seen_events_dict(self, event_ids: Collection[str]) -> Set[str]:
         """Helper for have_seen_events
 
         Returns:
-            a dict {(room_id, event_id)-> bool}
+            The remaining set of events we haven't seen.
         """
-        # if the event cache contains the event, obviously we've seen it.
-        cache_results = {
-            (rid, eid)
-            for (rid, eid) in keys
-            if await self._get_event_cache.contains((eid,))
+        # if the event cache contains the event, obviously we've seen it.
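+        # Filter out the cache hits up-front so that the database query
+        # below only has to cover the cache misses.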
+        event_cache_entry_map = self._get_events_from_local_cache(event_ids)
+        event_ids_in_cache = event_cache_entry_map.keys()
+        remaining_event_ids = {
+            event_id for event_id in event_ids if event_id not in event_ids_in_cache
         }
-        results = dict.fromkeys(cache_results, True)
-        remaining = [k for k in keys if k not in cache_results]
-        if not remaining:
-            return results
+        if not remaining_event_ids:
+            return set()
 
         def have_seen_events_txn(txn: LoggingTransaction) -> None:
+            nonlocal remaining_event_ids
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1507,23 +1512,24 @@ def have_seen_events_txn(txn: LoggingTransaction) -> None:
             sql = "SELECT event_id FROM events AS e WHERE "
 
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
+                txn.database_engine, "e.event_id", remaining_event_ids
             )
             txn.execute(sql + clause, args)
 
-            found_events = {eid for eid, in txn}
+            found_event_ids = {eid for (eid,) in txn}
 
-            # ... and then we can update the results for each key
-            results.update(
-                {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
-            )
+            remaining_event_ids = {
+                event_id
+                for event_id in remaining_event_ids
+                if event_id not in found_event_ids
+            }
 
         await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
-        return results
+        return remaining_event_ids
 
     @cached(max_entries=100000, tree=True)
     async def have_seen_event(self, room_id: str, event_id: str) -> bool:
-        res = await self._have_seen_events_dict(((room_id, event_id),))
-        return res[(room_id, event_id)]
+        remaining_event_ids = await self._have_seen_events_dict({event_id})
+        return event_id not in remaining_event_ids
 
     def _get_current_state_event_counts_txn(
         self, txn: LoggingTransaction, room_id: str
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 46d829b062a0..2b12038dd677 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import logging
 from contextlib import contextmanager
 from typing import Generator, List, Tuple
@@ -36,6 +37,8 @@
 
 from tests import unittest
 
+logger = logging.getLogger(__name__)
+
 
 class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
@@ -91,6 +94,71 @@
         )
         self.event_ids.append(event_id)
 
+    # TODO: Remove me before merging
+    def test_benchmark(self):
+        import time
+
+        room_id = "room123"
+        event_ids: List[str] = []
+        setup_start_time = time.time()
+        with LoggingContext(name="test-setup"):
+            for i in range(50000):
+                event_json = {"type": f"test {i}", "room_id": room_id}
+                event = make_event_from_dict(event_json, room_version=RoomVersions.V4)
+                event_id = event.event_id
+
+                event_ids.append(event_id)
+
+                self.get_success(
+                    self.store.db_pool.simple_insert(
+                        "events",
+                        {
+                            "event_id": event_id,
+                            "room_id": room_id,
+                            "topological_ordering": i,
+                            "stream_ordering": 123 + i,
+                            "type": event.type,
+                            "processed": True,
+                            "outlier": False,
+                        },
+                    )
+                )
+
+        setup_end_time = time.time()
+        logger.info(
+            "Setup time: %s",
+            (setup_end_time - setup_start_time),
+        )
+
+        with LoggingContext(name="test") as ctx:
+
+            def time_have_seen_events(test_prefix: str, event_ids: List[str]) -> None:
+                benchmark_start_time = time.time()
+                remaining_event_ids = self.get_success(
+                    self.store.have_seen_events(room_id, event_ids)
+                )
+                benchmark_end_time = time.time()
+                logger.info(
+                    "Benchmark time (%s): %s",
+                    f"{test_prefix: <13}",
+                    (benchmark_end_time - benchmark_start_time),
+                )
+                self.assertIsNotNone(remaining_event_ids)
+
+            event_ids_odd = event_ids[::2]
+            event_ids_even = event_ids[1::2]
+
+            time_have_seen_events("1, cold cache", event_ids)
+            time_have_seen_events("2, warm cache", event_ids)
+            time_have_seen_events("3, warm cache", event_ids)
+            time_have_seen_events("4, odds", event_ids_odd)
+            time_have_seen_events("5, odds", event_ids_odd)
+            time_have_seen_events("6, evens", event_ids_even)
+            time_have_seen_events("7, evens", event_ids_even)
+
+            # This should result in many batched db queries.
+            self.assertGreater(ctx.get_resource_usage().db_txn_count, 1)
+
     def test_simple(self):
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
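
A note on the inverted contract above: `have_seen_events` used to return the set
of events we *had* already seen, while after this patch it returns the events we
have *not* seen yet. A minimal sketch of the new calling pattern (the
`fetch_missing_events` helper is hypothetical, purely for illustration):

    from typing import Collection, Set

    async def process_pulled_event_ids(
        store, room_id: str, event_ids: Collection[str]
    ) -> None:
        # `remaining` is the subset of `event_ids` we have NOT processed yet,
        # so it can feed a fetch/backfill step directly.
        remaining: Set[str] = await store.have_seen_events(room_id, event_ids)
        if not remaining:
            return  # everything already seen; nothing to fetch
        await fetch_missing_events(room_id, remaining)  # hypothetical helper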