
Type annotations in synapse.databases.main.devices #13025

Merged: 22 commits, merged on Jun 15, 2022
1 change: 1 addition & 0 deletions changelog.d/13025.misc
@@ -0,0 +1 @@
+Add type annotations to `synapse.storage.databases.main.devices`.
1 change: 0 additions & 1 deletion mypy.ini
@@ -27,7 +27,6 @@ exclude = (?x)
^(
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
-|synapse/storage/databases/main/devices.py
|synapse/storage/schema/

|tests/api/test_auth.py
3 changes: 1 addition & 2 deletions synapse/replication/slave/storage/devices.py
@@ -19,13 +19,12 @@
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
-from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore

if TYPE_CHECKING:
from synapse.server import HomeServer


-class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
+class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def __init__(
self,
database: DatabasePool,
1 change: 1 addition & 0 deletions synapse/storage/databases/main/__init__.py
@@ -197,6 +197,7 @@ def __init__(
self._min_stream_order_on_start = self.get_room_min_stream_ordering()

def get_device_stream_token(self) -> int:
+# TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token()

async def get_users(self) -> List[JsonDict]:
58 changes: 37 additions & 21 deletions synapse/storage/databases/main/devices.py
@@ -28,6 +28,8 @@
cast,
)

+from typing_extensions import Literal

from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -44,6 +46,8 @@
LoggingTransaction,
make_tuple_comparison_clause,
)
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -65,7 +69,7 @@
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"


-class DeviceWorkerStore(SQLBaseStore):
+class DeviceWorkerStore(EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -74,7 +78,9 @@ def __init__(
):
super().__init__(database, db_conn, hs)

-device_list_max = self._device_list_id_gen.get_current_token()
+# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
+# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
+device_list_max = self._device_list_id_gen.get_current_token()  # type: ignore[attr-defined]
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@@ -339,8 +345,9 @@ async def get_device_updates_by_remote(
# following this stream later.
last_processed_stream_id = from_stream_id

-query_map = {}
-cross_signing_keys_by_user = {}
+# A map of (user ID, device ID) to (stream ID, context).
+query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
+cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
for user_id, device_id, update_stream_id, update_context in updates:
# Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
@@ -596,7 +603,7 @@ def _mark_as_sent_devices_by_remote_txn(
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
-key_values=((destination, user_id) for user_id, _ in rows),
+key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",),
value_values=((stream_id,) for _, stream_id in rows),
)
@@ -621,7 +628,9 @@ async def add_user_signature_change_to_streams(
The new stream ID.
"""

-async with self._device_list_id_gen.get_next() as stream_id:
+# TODO: this looks like it's _writing_. Should this be on DeviceStore rather
+# than DeviceWorkerStore?
+async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -686,7 +695,7 @@ async def get_user_devices_from_cache(
} - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache

-results = {}
+results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
@@ -727,7 +736,7 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
def get_cached_device_list_changes(
self,
from_key: int,
-) -> Optional[Set[str]]:
+) -> Optional[List[str]]:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@@ -737,7 +746,7 @@ def get_cached_device_list_changes(
async def get_users_whose_devices_changed(
self,
from_key: int,
-user_ids: Optional[Iterable[str]] = None,
+user_ids: Optional[Collection[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -757,6 +766,7 @@ async def get_users_whose_devices_changed(
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
+user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@@ -772,7 +782,7 @@
return set()

def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-changes = set()
+changes: Set[str] = set()

stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
@@ -788,6 +798,9 @@ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
"""

# Query device changes with a batch of users at a time
+# Assertion for mypy's benefit; see also
+# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
+assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
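
The assert above is needed because mypy deliberately discards narrowing inside inner functions: the closed-over variable could in principle be rebound before the inner function runs. A standalone sketch of the same shape (not Synapse code):

    from typing import Collection, Optional


    def count_users(user_ids: Optional[Collection[str]]) -> int:
        user_ids_to_check: Optional[Collection[str]]
        if user_ids is None:
            user_ids_to_check = ("@alice:example.com",)  # pretend cache lookup
        else:
            user_ids_to_check = user_ids

        # Narrowed to non-None at this point, but...
        def txn_body() -> int:
            # ...mypy forgets the narrowing inside the closure, so without
            # this assert it rejects passing an Optional value to list().
            assert user_ids_to_check is not None
            return len(list(user_ids_to_check))

        return txn_body()


    print(count_users(None))  # 1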
@@ -854,7 +867,9 @@ async def get_all_device_list_changes_for_remotes(
if last_id == current_id:
return [], current_id, False

-def _get_all_device_list_changes_for_remotes(txn):
+def _get_all_device_list_changes_for_remotes(
+    txn: Cursor,
+) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
@@ -913,7 +928,7 @@ async def get_device_list_last_stream_id_for_remotes(
desc="get_device_list_last_stream_id_for_remotes",
)

-results = {user_id: None for user_id in user_ids}
+results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})

return results
@@ -1346,9 +1361,9 @@ def __init__(

# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
-self.device_id_exists_cache = LruCache(
-    cache_name="device_id_exists", max_size=10000
-)
+self.device_id_exists_cache: LruCache[
+    Tuple[str, str], Literal[True]
+] = LruCache(cache_name="device_id_exists", max_size=10000)

async def store_device(
self,
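
The `Literal[True]` value type on `device_id_exists_cache` above records that entries are only ever `True`: presence in the cache is what carries the information. The same idea with a plain dict (a sketch only, since `LruCache` is Synapse's own generic class):

    from typing import Dict, Tuple

    from typing_extensions import Literal

    # Presence of a (user_id, device_id) key means "this device exists".
    device_id_exists: Dict[Tuple[str, str], Literal[True]] = {}

    device_id_exists[("@alice:example.com", "ABCDEFGH")] = True
    # device_id_exists[("@bob:example.com", "IJKLMNOP")] = False  # rejected by mypy

    if ("@alice:example.com", "ABCDEFGH") in device_id_exists:
        print("device is known to exist")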
@@ -1660,7 +1675,7 @@ def add_device_changes_txn(
context,
)

-async with self._device_list_id_gen.get_next_mult(
+async with self._device_list_id_gen.get_next_mult(  # type: ignore[attr-defined]
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1800,7 +1815,7 @@ def _add_device_outbound_room_poke_txn(

async def get_uncoverted_outbound_room_pokes(
self, limit: int = 10
-) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+) -> List[Tuple[str, str, str, int, Dict[str, str]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.

@@ -1818,7 +1833,7 @@ def get_uncoverted_outbound_room_pokes(

def get_uncoverted_outbound_room_pokes_txn(
txn: LoggingTransaction,
-) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+) -> List[Tuple[str, str, str, int, Dict[str, str]]]:
txn.execute(sql, (limit,))

return [
@@ -1827,7 +1842,7 @@ def get_uncoverted_outbound_room_pokes_txn(
device_id,
room_id,
stream_id,
-db_to_json(opentracing_context),
@DMRobertson (Contributor, Author) commented on Jun 10, 2022:

I wasn't sure I liked this one. The original complaint was

synapse/storage/databases/main/devices.py:1860: error: Argument 2 to "runInteraction" of "DatabasePool" has incompatible type "Callable[[LoggingTransaction], List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]]"; expected "Callable[..., List[Tuple[str, str, str, int, Dict[str, str]]]]" [arg-type]

Which corresponds to the context argument of this function:

def _add_device_outbound_poke_to_stream_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_ids: Iterable[str],
    hosts: Collection[str],
    stream_ids: List[int],
    context: Dict[str, str],
) -> None:
    for host in hosts:
        txn.call_after(
            self._device_list_federation_stream_cache.entity_has_changed,
            host,
            stream_ids[-1],
        )
    now = self._clock.time_msec()
    stream_id_iterator = iter(stream_ids)
    encoded_context = json_encoder.encode(context)
    self.db_pool.simple_insert_many_txn(
        txn,
        table="device_lists_outbound_pokes",
        keys=(
            "destination",
            "stream_id",
            "user_id",
            "device_id",
            "sent",
            "ts",
            "opentracing_context",
        ),
        values=[
            (
                destination,
                next(stream_id_iterator),
                user_id,
                device_id,
                not self.hs.is_mine_id(
                    user_id
                ),  # We only need to send out update for *our* users
                now,
                encoded_context if whitelisted_homeserver(destination) else "{}",
            )
            for destination in hosts
            for device_id in device_ids
        ],
    )

If we pass `context=None`, we'll first transform it to a JSON null and store that blob of JSON as the four-codepoint string `null`. I don't think that's what we want!!
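
Sketched with the stdlib `json` module standing in for Synapse's `json_encoder`, the failure mode is:

    import json

    context = None
    encoded = json.dumps(context)
    print(repr(encoded))  # 'null' — the four-codepoint string, not SQL NULL

    # It round-trips back to None in Python, but the text column would now
    # hold 'null', which an `opentracing_context IS NULL` query won't match.
    print(json.loads(encoded) is None)  # True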

@DMRobertson (Contributor, Author):

Note that opentracing_context is a nullable text field in the DB:

matrix=> \d device_lists_changes_in_room
Did not find any relation named "device_lists_changes_in_room".
matrix=> \d matrix.device_lists_changes_in_room
             Table "matrix.device_lists_changes_in_room"
          Column           |  Type   | Collation | Nullable | Default 
---------------------------+---------+-----------+----------+---------
 user_id                   | text    |           | not null | 
 device_id                 | text    |           | not null | 
 room_id                   | text    |           | not null | 
 stream_id                 | bigint  |           | not null | 
 converted_to_destinations | boolean |           | not null | 
 opentracing_context       | text    |           |          | 
Indexes:
    "device_lists_changes_in_stream_id" UNIQUE, btree (stream_id, room_id)
    "device_lists_changes_in_stream_id_unconverted" btree (stream_id) WHERE NOT converted_to_destinations

@clokep (Member) commented on Jun 14, 2022:

Do you know if we have any "null" data already stored though? It is very unclear to me if it is safe to change this.

@DMRobertson (Contributor, Author):

Not on matrix.org:

matrix=> select * from matrix.device_lists_changes_in_room where opentracing_context='null' limit 1;
(0 rows)

Time: 11322.423 ms (00:11.322)
matrix=> select * from matrix.device_lists_changes_in_room where opentracing_context='"null"' limit 1;
(0 rows)

Time: 10651.809 ms (00:10.652)
matrix=> select * from matrix.device_lists_changes_in_room where opentracing_context is NULL limit 1;
(0 rows)

Time: 3886.793 ms (00:03.887)

I suppose a safer approach would be to write a migration which makes the column non-nullable.

Maybe this is better dropped though; ideally type hints wouldn't kick off invasive changes.

@DMRobertson (Contributor, Author):

1569b14 reverts this and 134d3da presents an alternative.

(Member):

Thanks. I'd be curious what @erikjohnston thinks here since this was just added in #12321 (and maybe #13045). Is there a long term plan here we're unsure about?

(Member):

I'm a bit confused as to why we can't let it be None if I'm honest?

@DMRobertson (Contributor, Author):

I was mainly just trying to keep the existing type hints happy. Some of them say Dict[] and some of them say Optional[Dict].

Maybe we just use Optional[Dict] everywhere?

(Member):

Yeah, I'd go with Optional[Dict] if the DB has a nullable column?

@DMRobertson (Contributor, Author):

8395995 and 60e2ce0 do this.

+db_to_json(opentracing_context) or {},
)
for user_id, device_id, room_id, stream_id, opentracing_context in txn
]
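
The `or {}` above is what bridges the nullable column and the non-Optional return type: a decoded JSON null (`None` in Python) is replaced by an empty dict. A sketch of the normalisation, with `json.loads` standing in for `db_to_json`:

    import json
    from typing import Dict, Optional


    def context_from_column(raw: Optional[str]) -> Dict[str, str]:
        # raw mirrors the opentracing_context text column.
        decoded = json.loads(raw) if raw is not None else None
        return decoded or {}


    print(context_from_column('{"trace": "abc"}'))  # {'trace': 'abc'}
    print(context_from_column("null"))              # {} (a stored JSON null)
    print(context_from_column(None))                # {} (SQL NULL)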
Expand All @@ -1843,7 +1858,7 @@ async def add_device_list_outbound_pokes(
room_id: str,
stream_id: int,
hosts: Collection[str],
-context: Optional[Dict[str, str]],
+context: Dict[str, str],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
@@ -1884,7 +1899,8 @@ def add_device_list_outbound_pokes_txn(
[],
)

-async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+# is a StreamIdGenerator, or SlavedDataStore where it is a SlavedIdTracker.
+async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:  # type: ignore[attr-defined]
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,