This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse/storage/databases/main/events_bg_updates.py #11654

Merged: 1 commit, Dec 30, 2021
1 change: 1 addition & 0 deletions changelog.d/11654.misc
@@ -0,0 +1 @@
Add missing type hints to storage classes.
4 changes: 3 additions & 1 deletion mypy.ini
@@ -28,7 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
@@ -202,6 +201,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.event_push_actions]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.events_bg_updates]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True

69 changes: 40 additions & 29 deletions synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast

import attr

@@ -240,12 +240,14 @@ def __init__(

################################################################################

async def _background_reindex_fields_sender(self, progress, batch_size):
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)

def reindex_txn(txn):
def reindex_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
@@ -307,12 +309,14 @@ def reindex_txn(txn):

return result

async def _background_reindex_origin_server_ts(self, progress, batch_size):
async def _background_reindex_origin_server_ts(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)

def reindex_search_txn(txn):
def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -381,7 +385,9 @@ def reindex_search_txn(txn):

return result

async def _cleanup_extremities_bg_update(self, progress, batch_size):
async def _cleanup_extremities_bg_update(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to clean out extremities that should have been
deleted previously.

@@ -402,12 +408,12 @@ async def _cleanup_extremities_bg_update(self, progress, batch_size):
# have any descendants, but if they do then we should delete those
# extremities.

def _cleanup_extremities_bg_update_txn(txn):
def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
# The set of extremity event IDs that we're checking this round
original_set = set()

# A dict[str, set[str]] of event ID to their prev events.
graph = {}
# A dict[str, Set[str]] of event ID to their prev events.
graph: Dict[str, Set[str]] = {}

# The set of descendants of the original set that are not rejected
# nor soft-failed. Ancestors of these events should be removed
@@ -536,7 +542,7 @@ def _cleanup_extremities_bg_update_txn(txn):
room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
)

self.db_pool.simple_delete_many_txn(
@@ -558,7 +564,7 @@ def _cleanup_extremities_bg_update_txn(txn):
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
)

def _drop_table_txn(txn):
def _drop_table_txn(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE _extremities_to_check")

await self.db_pool.runInteraction(
@@ -567,11 +573,11 @@ def _drop_table_txn(txn):

return num_handled

async def _redactions_received_ts(self, progress, batch_size):
async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
"""Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")

def _redactions_received_ts_txn(txn):
def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
# Fetch the set of event IDs that we want to update
sql = """
SELECT event_id FROM redactions
@@ -622,10 +628,12 @@ def _redactions_received_ts_txn(txn):

return count

async def _event_fix_redactions_bytes(self, progress, batch_size):
async def _event_fix_redactions_bytes(
self, progress: JsonDict, batch_size: int
) -> int:
"""Undoes hex encoded censored redacted event JSON."""

def _event_fix_redactions_bytes_txn(txn):
def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
# This update is quite fast due to new index.
txn.execute(
"""
@@ -650,11 +658,11 @@ def _event_fix_redactions_bytes_txn(txn):

return 1

async def _event_store_labels(self, progress, batch_size):
async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")

def _event_store_labels_txn(txn):
def _event_store_labels_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT event_id, json FROM event_json
@@ -754,7 +762,10 @@ def get_rejected_events(
),
)

return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
return cast(
List[Tuple[str, str, JsonDict, bool, bool]],
[(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
)

results = await self.db_pool.runInteraction(
desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -912,7 +923,7 @@ async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:

def _calculate_chain_cover_txn(
self,
txn: Cursor,
txn: LoggingTransaction,
last_room_id: str,
last_depth: int,
last_stream: int,
@@ -1023,10 +1034,10 @@ def _calculate_chain_cover_txn(
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
self.event_chain_id_gen,
self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
event_to_auth_chain,
cast(Dict[str, Sequence[str]], event_to_auth_chain),
Contributor (author):
If not cast here:

synapse/storage/databases/main/events_bg_updates.py:1040: error: Argument 6 to "_add_chain_cover_index" of "PersistEventsStore" has incompatible type "Dict[str, List[str]]"; expected "Dict[str, Sequence[str]]"  [arg-type]
synapse/storage/databases/main/events_bg_updates.py:1040: note: "Dict" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance
synapse/storage/databases/main/events_bg_updates.py:1040: note: Consider using "Mapping" instead, which is covariant in the value type

Is there a better solution?

Member:

I'm unsure if there's a better solution; this seems like it might be fine? In this method the sequence is mutable, while in _add_chain_cover_index it doesn't ever need it to be mutable so we're telling mypy it can treat the second argument as non-mutable.

Contributor:

This part of Python's type system is always a pain.
There's no nice solution here, as much as I always try to see one.
To get around it without having to cheat the system, you'd have to be able to 'consume' a Dict[str, List[str]] and turn it into Dict[str, Sequence[str]] such that you couldn't ever re-use the Dict[str, List[str]] (since it might contain e.g. tuples).

You could probably do that by cloning it, since the clone and the old one wouldn't be the same thing, but that seems like a waste of time/memory when the dict has already been built.

Realistically, you're probably better off leaving this cast in, as much as that makes me sad. With any luck there are some tests that check this part of the code that will ensure the types are right... :-)
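
For illustration, a minimal, self-contained sketch of the invariance issue mypy is flagging; the two helper functions below are invented for the example and are not Synapse code:

from typing import Dict, List, Mapping, Sequence, cast

def needs_dict_of_sequences(data: Dict[str, Sequence[str]]) -> None:
    # Only reads the values; never mutates them.
    for key, values in data.items():
        print(key, list(values))

def accepts_mapping(data: Mapping[str, Sequence[str]]) -> None:
    # Mapping is covariant in its value type, so a Dict[str, List[str]]
    # is accepted here without any cast.
    for key, values in data.items():
        print(key, list(values))

event_to_auth_chain: Dict[str, List[str]] = {"$event": ["$auth_1", "$auth_2"]}

# Rejected by mypy: Dict is invariant in its value type, so a
# Dict[str, List[str]] is not a Dict[str, Sequence[str]].
# needs_dict_of_sequences(event_to_auth_chain)

# Workaround taken in this PR: cast, and promise not to mutate the values.
needs_dict_of_sequences(cast(Dict[str, Sequence[str]], event_to_auth_chain))

# Alternative from mypy's note: widen the parameter to Mapping.
accepts_mapping(event_to_auth_chain)

Widening the callee's parameter to Mapping[str, Sequence[str]] would avoid the cast entirely, at the cost of changing the signature of _add_chain_cover_index itself.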

)

return _CalculateChainCover(
Expand All @@ -1046,7 +1057,7 @@ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> in
"""
current_event_id = progress.get("current_event_id", "")

def purged_chain_cover_txn(txn) -> int:
def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
# The event ID from events will be null if the chain ID / sequence
# number points to a purged event.
sql = """
@@ -1181,14 +1192,14 @@ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
self._invalidate_cache_and_stream(
txn, self.get_relations_for_event, cache_tuple
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream(
txn, self.get_aggregation_groups_for_event, cache_tuple
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream(
txn, self.get_thread_summary, cache_tuple
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
Comment on lines +1195 to +1202
Contributor:
Ideally, if you redefined the class to have these stores as ancestors, you wouldn't need these ignores:

class EventsBackgroundUpdatesStore(RelationsStore, CacheInvalidationWorkerStore):

However, that seems to mess things up here — sad, but something we're aware of and want to sort out one day :|.

Member:

I've been doing the same in my code, I figure the stores inheritance stuff is a separate problem to solve.
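
As a toy sketch of the trade-off in that suggestion (the class bodies below are invented for illustration and are not the real Synapse store hierarchy):

class CacheInvalidationWorkerStore:
    # Stand-in for the store that actually defines the helper.
    def _invalidate_cache_and_stream(
        self, txn: object, cache: object, key: tuple
    ) -> None:
        ...

class UpdatesStoreWithDeclaredBase(CacheInvalidationWorkerStore):
    def _run(self, txn: object) -> None:
        # mypy resolves the attribute through the declared base class,
        # so no ignore is needed.
        self._invalidate_cache_and_stream(txn, None, ("key",))

class UpdatesStoreAsPlainMixin:
    def _run(self, txn: object) -> None:
        # The attribute only exists on the final, combined store class at
        # runtime, so mypy cannot see it here and the call needs a
        # targeted ignore.
        self._invalidate_cache_and_stream(txn, None, ("key",))  # type: ignore[attr-defined]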

)

if results:
@@ -1220,7 +1231,7 @@ async def _background_populate_stream_ordering2(
"""
batch_size = max(batch_size, 1)

def process(txn: Cursor) -> int:
def process(txn: LoggingTransaction) -> int:
last_stream = progress.get("last_stream", -(1 << 31))
txn.execute(
"""