From 4924d9278fac3e6cd2527392972177316d280f52 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 07:49:27 +0200 Subject: [PATCH 01/15] Add some type hints to datastore --- mypy.ini | 2 - synapse/storage/databases/main/metrics.py | 67 +++---- synapse/storage/databases/main/push_rule.py | 177 +++++++++++-------- synapse/storage/databases/main/roommember.py | 110 ++++++++---- 4 files changed, 210 insertions(+), 146 deletions(-) diff --git a/mypy.ini b/mypy.ini index ba0de419f5ea..36e70ba1b1e3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,8 +28,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/event_federation.py - |synapse/storage/databases/main/push_rule.py - |synapse/storage/databases/main/roommember.py |synapse/storage/schema/ |tests/api/test_auth.py diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 1480a0f04829..4f3cb1f5e9fc 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -14,12 +14,16 @@ import calendar import logging import time -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, List, Tuple, cast from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) @@ -71,8 +75,8 @@ def __init__( self._last_user_visit_update = self._get_start_of_day() @wrap_as_background_process("read_forward_extremities") - async def _read_forward_extremities(self): - def fetch(txn): + async def _read_forward_extremities(self) -> None: + def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]: txn.execute( """ SELECT t1.c, t2.c @@ -85,7 +89,7 @@ def fetch(txn): ) t2 ON t1.room_id = t2.room_id """ ) - return txn.fetchall() + return cast(List[Tuple[int, int]], txn.fetchall()) res = await self.db_pool.runInteraction("read_forward_extremities", fetch) @@ -95,7 +99,7 @@ def fetch(txn): (x[0] - 1) * x[1] for x in res if x[1] ) - async def count_daily_e2ee_messages(self): + async def count_daily_e2ee_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -103,20 +107,20 @@ async def count_daily_e2ee_messages(self): call to this function, it will return None. """ - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) - async def count_daily_sent_e2ee_messages(self): - def _count_messages(txn): + async def count_daily_sent_e2ee_messages(self) -> int: + def _count_messages(txn: LoggingTransaction) -> int: # This is good enough as if you have silly characters in your own # hostname then that's your own fault. 
like_clause = "%:" + self.hs.hostname @@ -136,22 +140,22 @@ def _count_messages(txn): "count_daily_sent_e2ee_messages", _count_messages ) - async def count_daily_active_e2ee_rooms(self): - def _count(txn): + async def count_daily_active_e2ee_rooms(self) -> int: + def _count(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( "count_daily_active_e2ee_rooms", _count ) - async def count_daily_messages(self): + async def count_daily_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -159,20 +163,20 @@ async def count_daily_messages(self): call to this function, it will return None. """ - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(*) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_messages", _count_messages) - async def count_daily_sent_messages(self): - def _count_messages(txn): + async def count_daily_sent_messages(self) -> int: + def _count_messages(txn: LoggingTransaction) -> int: # This is good enough as if you have silly characters in your own # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname @@ -192,15 +196,15 @@ def _count_messages(txn): "count_daily_sent_messages", _count_messages ) - async def count_daily_active_rooms(self): - def _count(txn): + async def count_daily_active_rooms(self) -> int: + def _count(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) @@ -226,7 +230,7 @@ async def count_monthly_users(self) -> int: "count_monthly_users", self._count_users, thirty_days_ago ) - def _count_users(self, txn, time_from): + def _count_users(self, txn: LoggingTransaction, time_from: int) -> int: """ Returns number of users seen in the past time_from period """ @@ -238,7 +242,7 @@ def _count_users(self, txn, time_from): ) u """ txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count async def count_r30_users(self) -> Dict[str, int]: @@ -252,7 +256,7 @@ async def count_r30_users(self) -> Dict[str, int]: A mapping of counts globally as well as broken out by platform. 
""" - def _count_r30_users(txn): + def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) thirty_days_ago_in_secs = now - thirty_days_in_secs @@ -317,7 +321,7 @@ def _count_r30_users(txn): txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) results["all"] = count return results @@ -344,7 +348,7 @@ async def count_r30v2_users(self) -> Dict[str, int]: - "web" (any web application -- it's not possible to distinguish Element Web here) """ - def _count_r30v2_users(txn): + def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs @@ -441,11 +445,8 @@ def _count_r30v2_users(txn): thirty_days_in_secs * 1000, ), ) - row = txn.fetchone() - if row is None: - results["all"] = 0 - else: - results["all"] = row[0] + (count,) = cast(Tuple[int], txn.fetchone()) + results["all"] = count return results @@ -453,7 +454,7 @@ def _count_r30v2_users(txn): "count_r30v2_users", _count_r30v2_users ) - def _get_start_of_day(self): + def _get_start_of_day(self) -> int: """ Returns millisecond unixtime for start of UTC day. """ @@ -467,7 +468,7 @@ async def generate_user_daily_visits(self) -> None: Generates daily visit data for use in cohort/ retention analysis """ - def _generate_user_daily_visits(txn): + def _generate_user_daily_visits(txn: LoggingTransaction) -> None: logger.info("Calling _generate_user_daily_visits") today_start = self._get_start_of_day() a_day_in_milliseconds = 24 * 60 * 60 * 1000 diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 4ed913e24879..cfa809278f74 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -14,14 +14,18 @@ # limitations under the License. 
import abc import logging -from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore @@ -33,6 +37,7 @@ AbstractStreamIdTracker, StreamIdGenerator, ) +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -57,7 +62,11 @@ def _is_experimental_rule_enabled( return True -def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig): +def _load_rules( + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, +) -> List[JsonDict]: ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -137,7 +146,7 @@ def __init__( ) @abc.abstractmethod - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. Returns: @@ -146,7 +155,7 @@ def get_max_push_rules_stream_id(self): raise NotImplementedError() @cached(max_entries=5000) - async def get_push_rules_for_user(self, user_id): + async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]: rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, @@ -168,7 +177,7 @@ async def get_push_rules_for_user(self, user_id): return _load_rules(rows, enabled_map, self.hs.config.experimental) @cached(max_entries=5000) - async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]: + async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, @@ -184,7 +193,7 @@ async def have_push_rules_changed_for_user( return False else: - def have_push_rules_changed_txn(txn): + def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool: sql = ( "SELECT COUNT(stream_id) FROM push_rules_stream" " WHERE user_id = ? AND ? 
< stream_id" @@ -202,11 +211,13 @@ def have_push_rules_changed_txn(txn): list_name="user_ids", num_args=1, ) - async def bulk_get_push_rules(self, user_ids): + async def bulk_get_push_rules( + self, user_ids: Collection[str] + ) -> Dict[str, List[JsonDict]]: if not user_ids: return {} - results = {user_id: [] for user_id in user_ids} + results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules", @@ -250,7 +261,7 @@ async def copy_push_rule_from_room_to_room( condition["pattern"] = new_room_id # Add the rule for the new room - await self.add_push_rule( + await self.add_push_rule( # type: ignore[attr-defined] user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], @@ -286,11 +297,13 @@ async def copy_push_rules_from_room_to_room_for_user( list_name="user_ids", num_args=1, ) - async def bulk_get_push_rules_enabled(self, user_ids): + async def bulk_get_push_rules_enabled( + self, user_ids: Collection[str] + ) -> Dict[str, Dict[str, bool]]: if not user_ids: return {} - results = {user_id: {} for user_id in user_ids} + results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", @@ -331,7 +344,9 @@ async def get_all_push_rule_updates( if last_id == current_id: return [], current_id, False - def get_all_push_rule_updates_txn(txn): + def get_all_push_rule_updates_txn( + txn: LoggingTransaction + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, user_id FROM push_rules_stream @@ -340,7 +355,10 @@ def get_all_push_rule_updates_txn(txn): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] + updates = cast( + List[Tuple[int, tuple]], + [(stream_id, (user_id,)) for stream_id, user_id in txn], + ) limited = False upper_bound = current_id @@ -358,17 +376,17 @@ def get_all_push_rule_updates_txn(txn): class PushRuleStore(PushRulesWorkerStore): async def add_push_rule( self, - user_id, - rule_id, - priority_class, - conditions, - actions, - before=None, - after=None, + user_id: str, + rule_id: str, + priority_class: int, + conditions: List[Dict[str, str]], + actions: List[Union[dict, str]], + before: Optional[str] = None, + after: Optional[str] = None, ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - async with self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -400,17 +418,17 @@ async def add_push_rule( def _add_push_rule_relative_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + conditions_json: str, + actions_json: str, + before: str, + after: str, + ) -> None: # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. 
self.database_engine.lock_table(txn, "push_rules") @@ -470,15 +488,15 @@ def _add_push_rule_relative_txn( def _add_push_rule_highest_priority_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + conditions_json: str, + actions_json: str, + ) -> None: # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. self.database_engine.lock_table(txn, "push_rules") @@ -510,17 +528,17 @@ def _add_push_rule_highest_priority_txn( def _upsert_push_rule_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - conditions_json, - actions_json, - update_stream=True, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + priority: int, + conditions_json: str, + actions_json: str, + update_stream: bool = True, + ) -> None: """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -538,7 +556,7 @@ def _upsert_push_rule_txn( if txn.rowcount == 0: # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next() + push_rule_id = self._push_rule_id_gen.get_next() # type: ignore[attr-defined] self.db_pool.simple_insert_txn( txn, @@ -586,7 +604,7 @@ def _upsert_push_rule_txn( else: raise RuntimeError("Unknown database engine") - new_enable_id = self._push_rules_enable_id_gen.get_next() + new_enable_id = self._push_rules_enable_id_gen.get_next() # type: ignore[attr-defined] txn.execute(sql, (new_enable_id, user_id, rule_id, 1)) async def delete_push_rule(self, user_id: str, rule_id: str) -> None: @@ -600,7 +618,11 @@ async def delete_push_rule(self, user_id: str, rule_id: str) -> None: rule_id: The rule_id of the rule to be deleted """ - def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + def delete_push_rule_txn( + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + ) -> None: # we don't use simple_delete_one_txn because that would fail if the # user did not have a push_rule_enable row. self.db_pool.simple_delete_txn( @@ -615,7 +637,7 @@ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - async with self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -646,7 +668,7 @@ async def set_push_rule_enabled( Raises: RuleNotFoundException if the rule does not exist. 
""" - async with self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -661,15 +683,15 @@ async def set_push_rule_enabled( def _set_push_rule_enabled_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - enabled, - is_default_rule, - ): - new_id = self._push_rules_enable_id_gen.get_next() + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + enabled: bool, + is_default_rule: bool, + ) -> None: + new_id = self._push_rules_enable_id_gen.get_next() # type: ignore[attr-defined] if not is_default_rule: # first check it exists; we need to lock for key share so that a @@ -740,7 +762,11 @@ async def set_push_rule_actions( """ actions_json = json_encoder.encode(actions) - def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): + def set_push_rule_actions_txn( + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + ) -> None: if is_default_rule: # Add a dummy rule to the rules table with the user specified # actions. @@ -783,7 +809,7 @@ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): data={"actions": actions_json}, ) - async with self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -794,8 +820,15 @@ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): ) def _insert_push_rules_update_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None - ): + self, + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + op: str, + data: Optional[JsonDict] = None, + ) -> None: values = { "stream_id": stream_id, "event_stream_ordering": event_stream_ordering, @@ -814,5 +847,5 @@ def _insert_push_rules_update_txn( self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: return self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 48e83592e728..551df44659e8 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -37,7 +37,12 @@ wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -46,7 +51,7 @@ ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, get_domain_from_id +from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors 
import _CacheContext, cached, cachedList @@ -115,7 +120,7 @@ def __init__( ) @wrap_as_background_process("_count_known_servers") - async def _count_known_servers(self): + async def _count_known_servers(self) -> int: """ Count the servers that this server knows about. @@ -123,7 +128,7 @@ async def _count_known_servers(self): `synapse_federation_known_servers` LaterGauge to collect. """ - def _transact(txn): + def _transact(txn: LoggingTransaction) -> int: if isinstance(self.database_engine, Sqlite3Engine): query = """ SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) @@ -150,7 +155,9 @@ def _transact(txn): self._known_servers_count = max([count, 1]) return self._known_servers_count - def _check_safe_current_state_events_membership_updated_txn(self, txn): + def _check_safe_current_state_events_membership_updated_txn( + self, txn: LoggingTransaction + ) -> None: """Checks if it is safe to assume the new current_state_events membership column is up to date """ @@ -182,7 +189,9 @@ async def get_users_in_room(self, room_id: str) -> List[str]: "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: + def get_users_in_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -222,7 +231,9 @@ async def get_users_in_room_with_profiles( A mapping from user ID to ProfileInfo. """ - def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]: + def _get_users_in_room_with_profiles( + txn: LoggingTransaction + ) -> Dict[str, ProfileInfo]: sql = """ SELECT state_key, display_name, avatar_url FROM room_memberships as m INNER JOIN current_state_events as c @@ -250,7 +261,9 @@ async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: dict of membership states, pointing to a MemberSummary named tuple. """ - def _get_room_summary_txn(txn): + def _get_room_summary_txn( + txn: LoggingTransaction, + ) -> Dict[str, MemberSummary]: # first get counts. # We do this all in one transaction to keep the cache small. 
# FIXME: get rid of this when we have room_stats @@ -279,7 +292,7 @@ def _get_room_summary_txn(txn): """ txn.execute(sql, (room_id,)) - res = {} + res: Dict[str, MemberSummary] = {} for count, membership in txn: res.setdefault(membership, MemberSummary([], count)) @@ -400,7 +413,7 @@ async def get_rooms_for_local_user_where_membership_is( def _get_rooms_for_local_user_where_membership_is_txn( self, - txn, + txn: LoggingTransaction, user_id: str, membership_list: List[str], ) -> List[RoomsForUser]: @@ -488,7 +501,7 @@ async def get_rooms_for_user_with_stream_ordering( ) def _get_rooms_for_user_with_stream_ordering_txn( - self, txn, user_id: str + self, txn: LoggingTransaction, user_id: str ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called @@ -542,7 +555,7 @@ async def get_rooms_for_users_with_stream_ordering( ) def _get_rooms_for_users_with_stream_ordering_txn( - self, txn, user_ids: Collection[str] + self, txn: LoggingTransaction, user_ids: Collection[str] ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: clause, args = make_in_list_sql_clause( @@ -575,7 +588,7 @@ def _get_rooms_for_users_with_stream_ordering_txn( txn.execute(sql, [Membership.JOIN] + args) - result = {user_id: set() for user_id in user_ids} + result: Dict[str, set] = {user_id: set() for user_id in user_ids} for user_id, room_id, instance, stream_id in txn: result[user_id].add( GetRoomsForUserWithStreamOrdering( @@ -595,7 +608,9 @@ async def get_users_server_still_shares_room_with( if not user_ids: return set() - def _get_users_server_still_shares_room_with_txn(txn): + def _get_users_server_still_shares_room_with_txn( + txn: LoggingTransaction, + ) -> Set[str]: sql = """ SELECT state_key FROM current_state_events WHERE @@ -657,7 +672,7 @@ async def get_users_who_share_room_with_user( async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: - state_group = context.state_group + state_group: Union[object, Optional[int]] = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -671,7 +686,7 @@ async def get_joined_users_from_context( ) async def get_joined_users_from_state( - self, room_id, state_entry + self, room_id: str, state_entry: "_StateCacheEntry" ) -> Dict[str, ProfileInfo]: state_group = state_entry.state_group if not state_group: @@ -689,12 +704,12 @@ async def get_joined_users_from_state( @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) async def _get_joined_users_from_context( self, - room_id, - state_group, - current_state_ids, - cache_context, - event=None, - context=None, + room_id: str, + state_group: int, + current_state_ids: StateMap[str], + cache_context: _CacheContext, + event: Optional[EventBase] = None, + context: Optional[EventContext] = None, ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. 
However, it's important that it's never None, since two current_states @@ -765,14 +780,18 @@ async def _get_joined_users_from_context( return users_in_room @cached(max_entries=10000) - def _get_joined_profile_from_event_id(self, event_id): + def _get_joined_profile_from_event_id( + self, event_id: str + ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): + async def _get_joined_profiles_from_event_ids( + self, event_ids: Iterable[str] + ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -780,8 +799,7 @@ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): event_ids: The member event IDs to lookup Returns: - dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID - to `user_id` and ProfileInfo (or None if not join event). + Map from event ID to `user_id` and ProfileInfo (or None if not join event). """ rows = await self.db_pool.simple_select_many_batch( @@ -847,7 +865,9 @@ async def _check_host_room_membership( return True - async def get_joined_hosts(self, room_id: str, state_entry): + async def get_joined_hosts( + self, room_id: str, state_entry: "_StateCacheEntry" + ) -> FrozenSet[str]: state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -942,7 +962,7 @@ async def did_forget(self, user_id: str, room_id: str) -> bool: Returns False if they have since re-joined.""" - def f(txn): + def f(txn: LoggingTransaction) -> int: sql = ( "SELECT" " COUNT(*)" @@ -973,7 +993,7 @@ async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: The forgotten rooms. """ - def _get_forgotten_rooms_for_user_txn(txn): + def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]: # This is a slightly convoluted query that first looks up all rooms # that the user has forgotten in the past, then rechecks that list # to see if any have subsequently been updated. 
This is done so that @@ -1076,7 +1096,9 @@ async def is_local_host_in_room_ignoring_users( clause, ) - def _is_local_host_in_room_ignoring_users_txn(txn): + def _is_local_host_in_room_ignoring_users_txn( + txn: LoggingTransaction, + ) -> bool: txn.execute(sql, (room_id, Membership.JOIN, *args)) return bool(txn.fetchone()) @@ -1110,15 +1132,17 @@ def __init__( where_clause="forgotten = 1", ) - async def _background_add_membership_profile(self, progress, batch_size): + async def _background_add_membership_profile( + self, progress: dict, batch_size: int + ) -> int: target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start + "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] ) max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 + "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined] ) - def add_membership_profile_txn(txn): + def add_membership_profile_txn(txn: LoggingTransaction) -> int: sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events @@ -1182,13 +1206,17 @@ def add_membership_profile_txn(txn): return result - async def _background_current_state_membership(self, progress, batch_size): + async def _background_current_state_membership( + self, progress: dict, batch_size: int + ) -> int: """Update the new membership column on current_state_events. This works by iterating over all rooms in alphebetical order. """ - def _background_current_state_membership_txn(txn, last_processed_room): + def _background_current_state_membership_txn( + txn: LoggingTransaction, last_processed_room: str + ) -> Tuple[int, bool]: processed = 0 while processed < batch_size: txn.execute( @@ -1242,7 +1270,11 @@ def _background_current_state_membership_txn(txn, last_processed_room): return row_count -class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): +class RoomMemberStore( + RoomMemberWorkerStore, + RoomMemberBackgroundUpdateStore, + CacheInvalidationWorkerStore, +): def __init__( self, database: DatabasePool, @@ -1254,7 +1286,7 @@ def __init__( async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" - def f(txn): + def f(txn: LoggingTransaction) -> None: sql = ( "UPDATE" " room_memberships" @@ -1288,5 +1320,5 @@ class _JoinedHostsCache: # equal to anything else). state_group: Union[object, int] = attr.Factory(object) - def __len__(self): + def __len__(self) -> int: return sum(len(v) for v in self.hosts_to_joined_users.values()) From 467b13655c8f9006638ade4b6dc608262949c9c2 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 07:51:50 +0200 Subject: [PATCH 02/15] newsfile --- changelog.d/12717.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12717.misc diff --git a/changelog.d/12717.misc b/changelog.d/12717.misc new file mode 100644 index 000000000000..e793d08e5e3f --- /dev/null +++ b/changelog.d/12717.misc @@ -0,0 +1 @@ +Add some type hints to datastore. 
\ No newline at end of file From 632ee7582ddac81a5ee91309fe5c3dfd33c439e4 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 09:27:47 +0200 Subject: [PATCH 03/15] some fixes --- synapse/handlers/sync.py | 6 +++--- synapse/storage/databases/main/metrics.py | 4 ++-- synapse/storage/databases/main/push_rule.py | 4 ++-- synapse/storage/databases/main/roommember.py | 16 +++++++--------- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2c555a66d066..5d8437aa78bb 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -410,11 +410,11 @@ async def current_sync_for_user( set_tag(SynapseTags.SYNC_RESULT, bool(sync_result)) return sync_result - async def push_rules_for_user(self, user: UserID) -> JsonDict: + async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]: user_id = user.to_string() rules = await self.store.get_push_rules_for_user(user_id) - rules = format_push_rules_for_user(user, rules) - return rules + result = format_push_rules_for_user(user, rules) + return result async def ephemeral_by_room( self, diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 4f3cb1f5e9fc..f3eed03567bb 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -133,7 +133,7 @@ def _count_messages(txn: LoggingTransaction) -> int: """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( @@ -189,7 +189,7 @@ def _count_messages(txn: LoggingTransaction) -> int: """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index cfa809278f74..b9262f8d2e7a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -199,7 +199,7 @@ def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool: " WHERE user_id = ? AND ? 
< stream_id" ) txn.execute(sql, (user_id, last_id)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return bool(count) return await self.db_pool.runInteraction( @@ -345,7 +345,7 @@ async def get_all_push_rule_updates( return [], current_id, False def get_all_push_rule_updates_txn( - txn: LoggingTransaction + txn: LoggingTransaction, ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, user_id diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 551df44659e8..c6d3a2e18d80 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -51,7 +51,7 @@ ProfileInfo, RoomsForUser, ) -from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id +from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -189,9 +189,7 @@ async def get_users_in_room(self, room_id: str) -> List[str]: "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn( - self, txn: LoggingTransaction, room_id: str - ) -> List[str]: + def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -232,7 +230,7 @@ async def get_users_in_room_with_profiles( """ def _get_users_in_room_with_profiles( - txn: LoggingTransaction + txn: LoggingTransaction, ) -> Dict[str, ProfileInfo]: sql = """ SELECT state_key, display_name, avatar_url FROM room_memberships as m @@ -672,7 +670,7 @@ async def get_users_who_share_room_with_user( async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: - state_group: Union[object, Optional[int]] = context.state_group + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. 
we need to make sure that calls with a state_group @@ -1271,9 +1269,9 @@ def _background_current_state_membership_txn( class RoomMemberStore( - RoomMemberWorkerStore, - RoomMemberBackgroundUpdateStore, - CacheInvalidationWorkerStore, + RoomMemberWorkerStore, + RoomMemberBackgroundUpdateStore, + CacheInvalidationWorkerStore, ): def __init__( self, From 2a43ae663aa237eeb7930721e76335a8b8d83a3b Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 17:44:40 +0200 Subject: [PATCH 04/15] Fix import from `Cursor` --- synapse/storage/databases/main/metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index ec2dec81eb67..14294a0bb85a 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -27,7 +27,6 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.types import Cursor if TYPE_CHECKING: from synapse.server import HomeServer From e59ede6b48bf0d48b35b8ce26f624cb3e42d44a2 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 17:49:10 +0200 Subject: [PATCH 05/15] Fix var in `rest/client/push_rule.py` --- synapse/rest/client/push_rule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index b98640b14ac5..8191b4e32c34 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -148,9 +148,9 @@ async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDic # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rules = await self.store.get_push_rules_for_user(user_id) + rules_raw = await self.store.get_push_rules_for_user(user_id) - rules = format_push_rules_for_user(requester.user, rules) + rules = format_push_rules_for_user(requester.user, rules_raw) path_parts = path.split("/")[1:] From 1ff2959c5265351d38a766b6a37bf24be7741715 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 18:02:36 +0200 Subject: [PATCH 06/15] Change `state_group` to `Union[object, int]` --- synapse/storage/databases/main/roommember.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index c6d3a2e18d80..0d57ac82e95b 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -670,7 +670,7 @@ async def get_users_who_share_room_with_user( async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: - state_group = context.state_group + state_group: Union[object, int] = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. 
we need to make sure that calls with a state_group @@ -679,6 +679,8 @@ async def get_joined_users_from_context( state_group = object() current_state_ids = await context.get_current_state_ids() + assert current_state_ids is not None + assert state_group is not None return await self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) @@ -686,7 +688,7 @@ async def get_joined_users_from_context( async def get_joined_users_from_state( self, room_id: str, state_entry: "_StateCacheEntry" ) -> Dict[str, ProfileInfo]: - state_group = state_entry.state_group + state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -694,6 +696,7 @@ async def get_joined_users_from_state( # To do this we set the state_group to a new object as object() != object() state_group = object() + assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): return await self._get_joined_users_from_context( room_id, state_group, state_entry.state, context=state_entry @@ -703,11 +706,11 @@ async def get_joined_users_from_state( async def _get_joined_users_from_context( self, room_id: str, - state_group: int, + state_group: Union[object, int], current_state_ids: StateMap[str], cache_context: _CacheContext, event: Optional[EventBase] = None, - context: Optional[EventContext] = None, + context: Optional[Union[EventContext, "_StateCacheEntry"]] = None, ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states @@ -866,7 +869,7 @@ async def _check_host_room_membership( async def get_joined_hosts( self, room_id: str, state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: - state_group = state_entry.state_group + state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -874,6 +877,7 @@ async def get_joined_hosts( # To do this we set the state_group to a new object as object() != object() state_group = object() + assert state_group is not None with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( room_id, state_group, state_entry=state_entry @@ -881,7 +885,7 @@ async def get_joined_hosts( @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, room_id: str, state_group: int, state_entry: "_StateCacheEntry" + self, room_id: str, state_group: Union[object, int], state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: # We don't use `state_group`, it's there so that we can cache based on # it. However, its important that its never None, since two @@ -899,7 +903,7 @@ async def _get_joined_hosts( # `get_joined_hosts` is called with the "current" state group for the # room, and so consecutive calls will be for consecutive state groups # which point to the previous state group. - cache = await self._get_joined_hosts_cache(room_id) + cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc] # If the state group in the cache matches, we already have the data we need. 
if state_entry.state_group == cache.state_group: @@ -915,6 +919,7 @@ async def _get_joined_hosts( elif state_entry.prev_group == cache.state_group: # The cached work is for the previous state group, so we work out # the delta. + assert state_entry.delta_ids is not None for (typ, state_key), event_id in state_entry.delta_ids.items(): if typ != EventTypes.Member: continue From 2eddcafd9990a4363f2c342698be695af9f7af69 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 18:04:56 +0200 Subject: [PATCH 07/15] Update `Set` -> `FrozenSet` in `state/__init__.py` --- synapse/state/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 54e41d537584..0219091c4e8b 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -239,13 +239,13 @@ async def get_current_users_in_room( entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]: event_ids = await self.store.get_latest_event_ids_in_room(room_id) return await self.get_hosts_in_room_at_events(room_id, event_ids) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] - ) -> Set[str]: + ) -> FrozenSet[str]: """Get the hosts that were in a room at the given event ids Args: From d1738112f58f4071a03cf1074a60d38914725293 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 18:08:13 +0200 Subject: [PATCH 08/15] codestyle --- synapse/storage/databases/main/roommember.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 0d57ac82e95b..cc83433b9200 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -885,7 +885,10 @@ async def get_joined_hosts( @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, room_id: str, state_group: Union[object, int], state_entry: "_StateCacheEntry" + self, + room_id: str, + state_group: Union[object, int], + state_entry: "_StateCacheEntry", ) -> FrozenSet[str]: # We don't use `state_group`, it's there so that we can cache based on # it. 
However, its important that its never None, since two From 8b49b0f1ddba58f01317eadfc5d90cb7b20028de Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 12 May 2022 18:17:45 +0200 Subject: [PATCH 09/15] Revert changes in `state/__init__.py` --- synapse/state/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 0219091c4e8b..54e41d537584 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -239,13 +239,13 @@ async def get_current_users_in_room( entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]: + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: event_ids = await self.store.get_latest_event_ids_in_room(room_id) return await self.get_hosts_in_room_at_events(room_id, event_ids) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] - ) -> FrozenSet[str]: + ) -> Set[str]: """Get the hosts that were in a room at the given event ids Args: From 18e0f49c7c8f48f8c7c647bb4546e09c63d936e4 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sat, 14 May 2022 13:26:36 +0200 Subject: [PATCH 10/15] Change to `FrozenSet` --- synapse/federation/sender/__init__.py | 24 +++++++++++++++++------- synapse/handlers/sync.py | 6 +++--- synapse/state/__init__.py | 4 ++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 6d2f46318bea..dbe303ed9be8 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -15,7 +15,17 @@ import abc import logging from collections import OrderedDict -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -409,7 +419,7 @@ async def handle_event(event: EventBase) -> None: ) return - destinations: Optional[Set[str]] = None + destinations: Optional[Collection[str]] = None if not event.prev_event_ids(): # If there are no prev event IDs then the state is empty # and so no remote servers in the room @@ -444,7 +454,7 @@ async def handle_event(event: EventBase) -> None: ) return - destinations = { + sharded_destinations = { d for d in destinations if self._federation_shard_config.should_handle( @@ -456,12 +466,12 @@ async def handle_event(event: EventBase) -> None: # If we are sending the event on behalf of another server # then it already has the event and there is no reason to # send the event to it. 
- destinations.discard(send_on_behalf_of) + sharded_destinations.discard(send_on_behalf_of) - logger.debug("Sending %s to %r", event, destinations) + logger.debug("Sending %s to %r", event, sharded_destinations) - if destinations: - await self._send_pdu(event, destinations) + if sharded_destinations: + await self._send_pdu(event, sharded_destinations) now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5d8437aa78bb..792c73e330c6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -412,9 +412,9 @@ async def current_sync_for_user( async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]: user_id = user.to_string() - rules = await self.store.get_push_rules_for_user(user_id) - result = format_push_rules_for_user(user, rules) - return result + rules_raw = await self.store.get_push_rules_for_user(user_id) + rules = format_push_rules_for_user(user, rules_raw ) + return rules async def ephemeral_by_room( self, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 54e41d537584..0219091c4e8b 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -239,13 +239,13 @@ async def get_current_users_in_room( entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]: event_ids = await self.store.get_latest_event_ids_in_room(room_id) return await self.get_hosts_in_room_at_events(room_id, event_ids) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] - ) -> Set[str]: + ) -> FrozenSet[str]: """Get the hosts that were in a room at the given event ids Args: From 60828b99f603514dba34d25a28b2834ed4d93eb3 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sat, 14 May 2022 13:29:56 +0200 Subject: [PATCH 11/15] codestyle / spaces --- synapse/handlers/sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 792c73e330c6..d0eeb9b62e73 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -412,8 +412,8 @@ async def current_sync_for_user( async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]: user_id = user.to_string() - rules_raw = await self.store.get_push_rules_for_user(user_id) - rules = format_push_rules_for_user(user, rules_raw ) + rules_raw = await self.store.get_push_rules_for_user(user_id) + rules = format_push_rules_for_user(user, rules_raw) return rules async def ephemeral_by_room( From fa3112daef22266bf8a4a37c13a518c29b76eb54 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 16 May 2022 20:26:04 +0200 Subject: [PATCH 12/15] Review part 1 --- synapse/storage/databases/main/push_rule.py | 6 +++--- synapse/storage/databases/main/roommember.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index b9262f8d2e7a..1c1d484c4840 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -346,7 +346,7 @@ async def get_all_push_rule_updates( def get_all_push_rule_updates_txn( txn: LoggingTransaction, - ) -> 
Tuple[List[Tuple[int, tuple]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: sql = """ SELECT stream_id, user_id FROM push_rules_stream @@ -356,7 +356,7 @@ def get_all_push_rule_updates_txn( """ txn.execute(sql, (last_id, current_id, limit)) updates = cast( - List[Tuple[int, tuple]], + List[Tuple[int, Tuple[str]]], [(stream_id, (user_id,)) for stream_id, user_id in txn], ) @@ -380,7 +380,7 @@ async def add_push_rule( rule_id: str, priority_class: int, conditions: List[Dict[str, str]], - actions: List[Union[dict, str]], + actions: List[Union[JsonDict, str]], before: Optional[str] = None, after: Optional[str] = None, ) -> None: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index cc83433b9200..608d40dfa164 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -51,7 +51,7 @@ ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id +from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -586,7 +586,9 @@ def _get_rooms_for_users_with_stream_ordering_txn( txn.execute(sql, [Membership.JOIN] + args) - result: Dict[str, set] = {user_id: set() for user_id in user_ids} + result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { + user_id: set() for user_id in user_ids + } for user_id, room_id, instance, stream_id in txn: result[user_id].add( GetRoomsForUserWithStreamOrdering( @@ -1139,7 +1141,7 @@ def __init__( ) async def _background_add_membership_profile( - self, progress: dict, batch_size: int + self, progress: JsonDict, batch_size: int ) -> int: target_min_stream_id = progress.get( "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] @@ -1213,7 +1215,7 @@ def add_membership_profile_txn(txn: LoggingTransaction) -> int: return result async def _background_current_state_membership( - self, progress: dict, batch_size: int + self, progress: JsonDict, batch_size: int ) -> int: """Update the new membership column on current_state_events. From 81e6db603af5c59300ea0ec9851fb02db81e3d0e Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 16 May 2022 20:37:55 +0200 Subject: [PATCH 13/15] mypy --- synapse/storage/databases/main/push_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 1c1d484c4840..c7eae8192f10 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -319,7 +319,7 @@ async def bulk_get_push_rules_enabled( async def get_all_push_rule_updates( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: """Get updates for push_rules replication stream. 
Args: From 301e7c30ad76619e45a6ae07a9c1fe8a99a840b2 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 May 2022 08:11:58 +0200 Subject: [PATCH 14/15] add `AbstractStreamIdGenerator` --- synapse/storage/databases/main/push_rule.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index c7eae8192f10..d4dc2b53a6af 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -34,6 +34,7 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, AbstractStreamIdTracker, StreamIdGenerator, ) @@ -374,6 +375,10 @@ def get_all_push_rule_updates_txn( class PushRuleStore(PushRulesWorkerStore): + # Because we have write access, this will be a StreamIdGenerator + # (see PushRulesWorkerStore.__init__) + _push_rules_stream_id_gen: AbstractStreamIdGenerator + async def add_push_rule( self, user_id: str, @@ -386,7 +391,7 @@ async def add_push_rule( ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -637,7 +642,7 @@ def delete_push_rule_txn( txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -668,7 +673,7 @@ async def set_push_rule_enabled( Raises: RuleNotFoundException if the rule does not exist. 
""" - async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -809,7 +814,7 @@ def set_push_rule_actions_txn( data={"actions": actions_json}, ) - async with self._push_rules_stream_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( From 694af7a8b3af363e020685cdeee3897903338494 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 May 2022 08:31:34 +0200 Subject: [PATCH 15/15] Move pusher `IdGenerator` --- synapse/storage/databases/main/__init__.py | 8 +------- synapse/storage/databases/main/push_rule.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5895b892024c..d545a1c002c9 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -26,11 +26,7 @@ from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - IdGenerator, - MultiWriterIdGenerator, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -155,8 +151,6 @@ def __init__( ], ) - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._group_updates_id_gen = StreamIdGenerator( db_conn, "local_group_updates", "stream_id" ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4dc2b53a6af..0e2855fb446c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -36,6 +36,7 @@ from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, AbstractStreamIdTracker, + IdGenerator, StreamIdGenerator, ) from synapse.types import JsonDict @@ -379,6 +380,17 @@ class PushRuleStore(PushRulesWorkerStore): # (see PushRulesWorkerStore.__init__) _push_rules_stream_id_gen: AbstractStreamIdGenerator + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + async def add_push_rule( self, user_id: str, @@ -561,7 +573,7 @@ def _upsert_push_rule_txn( if txn.rowcount == 0: # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next() # type: ignore[attr-defined] + push_rule_id = self._push_rule_id_gen.get_next() self.db_pool.simple_insert_txn( txn, @@ -609,7 +621,7 @@ def _upsert_push_rule_txn( else: raise RuntimeError("Unknown database engine") - new_enable_id = self._push_rules_enable_id_gen.get_next() # type: ignore[attr-defined] + new_enable_id = 
self._push_rules_enable_id_gen.get_next() txn.execute(sql, (new_enable_id, user_id, rule_id, 1)) async def delete_push_rule(self, user_id: str, rule_id: str) -> None: @@ -696,7 +708,7 @@ def _set_push_rule_enabled_txn( enabled: bool, is_default_rule: bool, ) -> None: - new_id = self._push_rules_enable_id_gen.get_next() # type: ignore[attr-defined] + new_id = self._push_rules_enable_id_gen.get_next() if not is_default_rule: # first check it exists; we need to lock for key share so that a
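
A few typing patterns recur throughout this series; the sketches below illustrate them outside the diff. None of this code is part of the patch set — connections, table names, and helper names are invented for illustration.

The most common change is wrapping `txn.fetchone()` in `cast(Tuple[int], ...)`. DB-API cursors type `fetchone()` as returning an optional row of `Any`, so mypy cannot see that `SELECT COUNT(*)` always yields exactly one single-column row. A minimal standalone sketch, assuming sqlite3 and an invented `events` table:

    import sqlite3
    from typing import Tuple, cast

    # Invented schema, standing in for Synapse's events table.
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (type TEXT)")
    conn.execute("INSERT INTO events VALUES ('m.room.message')")

    cur = conn.execute(
        "SELECT COUNT(*) FROM events WHERE type = ?", ("m.room.message",)
    )
    # COUNT(*) always produces exactly one row, so the Optional return of
    # fetchone() can be narrowed with a cast; unpacking then yields `count: int`.
    (count,) = cast(Tuple[int], cur.fetchone())
    print(count)  # 1

Note that patch 01 also replaces an explicit `row is None` check in `_count_r30v2_users` with this cast, trading a runtime guard for a typing assertion; that is safe only because an aggregate query is guaranteed to return a row.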
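
Patch 06 widens `state_group` to `Union[object, int]` while keeping the existing comment "we set the state_group to a new object as object() != object()". The trick is that the caching decorator keys on its arguments, and a freshly created `object()` compares equal only to itself, so passing one forces a cache miss whenever no real state group id has been assigned yet. A sketch of the idea, with `functools.lru_cache` standing in for Synapse's `@cached` descriptor:

    from functools import lru_cache
    from typing import Dict, Union

    @lru_cache(maxsize=None)
    def _joined_users(room_id: str, state_group: Union[object, int]) -> Dict[str, str]:
        # Imagine an expensive state resolution here; lru_cache keys on
        # (room_id, state_group), so equal arguments reuse the cached result.
        return {"@user:example.com": room_id}

    _joined_users("!room:example.com", 42)        # computed, then cached
    _joined_users("!room:example.com", 42)        # cache hit
    _joined_users("!room:example.com", object())  # fresh sentinel: always a miss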
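
Patch 14 removes several `# type: ignore[attr-defined]` comments by re-annotating `_push_rules_stream_id_gen` at class level in `PushRuleStore`: the worker store only promises an `AbstractStreamIdTracker`, but the writable store always constructs a `StreamIdGenerator`, so the subclass can declare the narrower `AbstractStreamIdGenerator` and call `get_next()` without per-call-site ignores. A reduced sketch of the same technique, with all class names invented:

    import abc

    class Tracker(abc.ABC):
        @abc.abstractmethod
        def get_current_token(self) -> int: ...

    class Generator(Tracker):
        @abc.abstractmethod
        def get_next(self) -> int: ...

    class StreamIdGenerator(Generator):
        def __init__(self) -> None:
            self._token = 0
        def get_current_token(self) -> int:
            return self._token
        def get_next(self) -> int:
            self._token += 1
            return self._token

    class WorkerStore:
        # Readers only need the current token, so the base class promises
        # no more than a Tracker.
        _id_gen: Tracker
        def __init__(self) -> None:
            self._id_gen = StreamIdGenerator()

    class WriterStore(WorkerStore):
        # The writer always receives a real generator (see __init__ above),
        # so the attribute is re-declared with the narrower type and mypy
        # accepts get_next() without "type: ignore" at every call site.
        _id_gen: Generator
        def write(self) -> int:
            return self._id_gen.get_next()

    print(WriterStore().write())  # 1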
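
Finally, patches 05 and 10 rename intermediate variables (`rules` becomes `rules_raw`) rather than reusing one name: mypy fixes a variable's type at its first assignment, so rebinding the formatted result to the same name as the raw rows would not type-check. An illustrative sketch with invented stand-ins for the store and formatter:

    from typing import Any, Dict, List

    def get_push_rules() -> List[Dict[str, Any]]:
        # Invented stand-in for store.get_push_rules_for_user().
        return [{"rule_id": ".m.rule.master"}]

    def format_push_rules(raw: List[Dict[str, Any]]) -> Dict[str, list]:
        # Invented stand-in for format_push_rules_for_user().
        return {"global": [r["rule_id"] for r in raw]}

    rules_raw = get_push_rules()          # inferred as List[Dict[str, Any]]
    rules = format_push_rules(rules_raw)  # a second name, typed Dict[str, list]
    print(rules)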