This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit 54f8d73

Convert additional databases to async/await (#8199)
clokep authored Sep 1, 2020
1 parent 5bf8e5f commit 54f8d73
Showing 7 changed files with 147 additions and 137 deletions.
1 change: 1 addition & 0 deletions changelog.d/8199.misc
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
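Taken together, the changes below apply one mechanical pattern: methods that previously returned the Deferred produced by db_pool.runInteraction become async def coroutines that await it, gaining precise return-type annotations along the way. A minimal, self-contained sketch of the before/after shape (FakeDatabasePool and FakeStore are illustrative stand-ins, not Synapse's actual classes):

    import asyncio
    from typing import Any, Callable

    class FakeDatabasePool:
        """Hypothetical stand-in for synapse.storage.database.DatabasePool."""

        async def runInteraction(
            self, desc: str, func: Callable[..., Any], *args: Any
        ) -> Any:
            # Synapse runs `func` inside a database transaction on a threadpool
            # and returns an awaitable; here we simply invoke it directly.
            return func(*args)

    class FakeStore:
        def __init__(self) -> None:
            self.db_pool = FakeDatabasePool()

        # After the conversion, a value-returning method awaits the
        # interaction and carries an explicit annotation:
        async def count_daily_users(self) -> int:
            def _count_users(time_from: int) -> int:
                return 42  # placeholder for the real COUNT(*) query

            yesterday = 0  # placeholder for the real clock arithmetic
            return await self.db_pool.runInteraction(
                "count_daily_users", _count_users, yesterday
            )

        # A side-effect-only method awaits without returning, so its
        # annotation becomes `-> None`:
        async def generate_user_daily_visits(self) -> None:
            await self.db_pool.runInteraction(
                "generate_user_daily_visits", lambda: None
            )

    async def main() -> None:
        store = FakeStore()
        print(await store.count_daily_users())  # -> 42
        await store.generate_user_daily_visits()

    asyncio.run(main())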
50 changes: 28 additions & 22 deletions synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
import calendar
import logging
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -294,24 +294,24 @@ def _get_active_presence(self, db_conn):

return [UserPresenceState(**row) for row in rows]

def count_daily_users(self):
async def count_daily_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_daily_users", self._count_users, yesterday
)

def count_monthly_users(self):
async def count_monthly_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
from the mau figure in synapse.storage.monthly_active_users which,
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)

@@ -330,15 +330,15 @@ def _count_users(self, txn, time_from):
(count,) = txn.fetchone()
return count

def count_r30_users(self):
async def count_r30_users(self) -> Dict[str, int]:
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
Returns counts globally for a given user as well as breaking
out by platform
Returns:
A mapping of counts globally as well as broken out by platform.
"""

def _count_r30_users(txn):
@@ -411,7 +411,7 @@ def _count_r30_users(txn):

return results

return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
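For illustration, a hedged sketch of a query matching the three R30 bullets above; the table and column names and the uniform millisecond units are assumptions for readability, not Synapse's actual schema or SQL:

    THIRTY_DAYS_MS = 30 * 24 * 60 * 60 * 1000

    # Hypothetical query: both `?` placeholders are "now" in milliseconds.
    R30_COUNT_SQL = f"""
        SELECT COUNT(*) FROM users AS u
        INNER JOIN user_ips AS i ON i.user_id = u.name
        WHERE u.creation_ts < ? - {THIRTY_DAYS_MS}           -- created more than 30 days ago
          AND i.last_seen > ? - {THIRTY_DAYS_MS}             -- last seen at most 30 days ago
          AND i.last_seen - u.creation_ts > {THIRTY_DAYS_MS} -- the two are > 30 days apart
    """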

def _get_start_of_day(self):
"""
@@ -421,7 +421,7 @@ def _get_start_of_day(self):
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000

def generate_user_daily_visits(self):
async def generate_user_daily_visits(self) -> None:
"""
Generates daily visit data for use in cohort/retention analysis
"""
@@ -476,7 +476,7 @@ def _generate_user_daily_visits(txn):
# frequently
self._last_user_visit_update = now

return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)

@@ -500,22 +500,28 @@ async def get_users(self) -> List[Dict[str, Any]]:
desc="get_users",
)

def get_users_paginate(
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
async def get_users_paginate(
self,
start: int,
limit: int,
user_id: Optional[str] = None,
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
user_id (string): search for user_id. ignored if name is not None
name (string): search for local part of user_id or display name
guests (bool): whether to include guest users
deactivated (bool): whether to include deactivated users
start: start number to begin the query from
limit: number of rows to retrieve
user_id: search for user_id. ignored if name is not None
name: search for local part of user_id or display name
guests: whether to include guest users
deactivated: whether to include deactivated users
Returns:
defer.Deferred: resolves to list[dict[str, Any]], int
A tuple of a list of mappings from user to information and a count of total users.
"""

def get_users_paginate_txn(txn):
@@ -558,7 +564,7 @@ def get_users_paginate_txn(txn):
users = self.db_pool.cursor_to_dict(txn)
return users, count

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_paginate_txn", get_users_paginate_txn
)
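A hypothetical caller of the converted method, showing how the new (rows, total) tuple is consumed directly with await instead of via a Deferred (the store argument and the "name" key are assumptions for illustration):

    from typing import Any

    async def print_user_page(store: Any, start: int = 0, limit: int = 50) -> None:
        # Unpack the (users, total) tuple returned by the coroutine:
        users, total = await store.get_users_paginate(
            start, limit, guests=False, deactivated=True
        )
        for user in users:
            print(user["name"])  # assumes each row dict carries a "name" column
        print(f"showing {len(users)} of {total} matching users")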

38 changes: 18 additions & 20 deletions synapse/storage/databases/main/devices.py
@@ -313,9 +313,9 @@ async def _get_device_update_edus_by_remote(

return results

def _get_last_device_update_for_remote_user(
async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
):
) -> int:
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
@@ -326,12 +326,16 @@ def f(txn):
rows = txn.fetchall()
return rows[0][0]

return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
return await self.db_pool.runInteraction(
"get_last_device_update_for_remote_user", f
)

def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
async def mark_as_sent_devices_by_remote(
self, destination: str, stream_id: int
) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -684,7 +688,7 @@ async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
desc="make_remote_user_device_cache_as_stale",
)

def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""

@@ -698,7 +702,7 @@ def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)

return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
@@ -959,9 +963,9 @@ async def update_device(
desc="update_device",
)

def update_remote_device_list_cache_entry(
async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
):
) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@@ -972,11 +976,8 @@ def update_remote_device_list_cache_entry(
device_id: ID of the device being updated
content: new data on this device
stream_id: the version of the device list
Returns:
Deferred[None]
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -1028,9 +1029,9 @@ def _update_remote_device_list_cache_entry_txn(
lock=False,
)

def update_remote_device_list_cache(
async def update_remote_device_list_cache(
self, user_id: str, devices: List[dict], stream_id: int
):
) -> None:
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
@@ -1040,11 +1041,8 @@ def update_remote_device_list_cache(
user_id: User to update device list for
devices: list of device objects supplied over federation
stream_id: the version of the device list
Returns:
Deferred[None]
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -1054,7 +1052,7 @@ def update_remote_device_list_cache(

def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
):
) -> None:
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
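The cache update above follows a delete-then-reinsert shape inside a single transaction. A hedged sqlite3 sketch of the same pattern (the schema is a simplification; only device_lists_remote_cache is named in the diff, and device_lists_version is a hypothetical table):

    import json
    import sqlite3
    from typing import List

    def replace_device_cache_txn(
        conn: sqlite3.Connection, user_id: str, devices: List[dict], stream_id: int
    ) -> None:
        cur = conn.cursor()
        # Drop every cached device for this user...
        cur.execute(
            "DELETE FROM device_lists_remote_cache WHERE user_id = ?", (user_id,)
        )
        # ...then re-insert the authoritative list supplied over federation...
        cur.executemany(
            "INSERT INTO device_lists_remote_cache (user_id, device_id, content)"
            " VALUES (?, ?, ?)",
            [(user_id, d["device_id"], json.dumps(d)) for d in devices],
        )
        # ...and record which version of the list we now hold (hypothetical table).
        cur.execute(
            "INSERT OR REPLACE INTO device_lists_version (user_id, stream_id)"
            " VALUES (?, ?)",
            (user_id, stream_id),
        )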
48 changes: 28 additions & 20 deletions synapse/storage/databases/main/events_worker.py
@@ -823,20 +823,24 @@ def _fetch_event_rows(self, txn, event_ids):

return event_dict

def _maybe_redact_event_row(self, original_ev, redactions, event_map):
def _maybe_redact_event_row(
self,
original_ev: EventBase,
redactions: Iterable[str],
event_map: Dict[str, EventBase],
) -> Optional[EventBase]:
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
Args:
original_ev (EventBase):
redactions (iterable[str]): list of event ids of potential redaction events
event_map (dict[str, EventBase]): other events which have been fetched, in
which we can look up the redaction events. Map from event id to event.
original_ev: The original event.
redactions: list of event ids of potential redaction events
event_map: other events which have been fetched, in which we can
look up the redaction events. Map from event id to event.
Returns:
Deferred[EventBase|None]: if the event should be redacted, a pruned
event object. Otherwise, None.
If the event should be redacted, a pruned event object. Otherwise, None.
"""
if original_ev.type == "m.room.create":
# we choose to ignore redactions of m.room.create events.
@@ -946,17 +950,17 @@ def _get_current_state_event_counts_txn(self, txn, room_id):
row = txn.fetchone()
return row[0] if row else 0

def get_current_state_event_counts(self, room_id):
async def get_current_state_event_counts(self, room_id: str) -> int:
"""
Gets the current number of state events in a room.
Args:
room_id (str)
room_id: The room ID to query.
Returns:
Deferred[int]
The current number of state events.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
@@ -991,15 +995,17 @@ def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()

def get_all_new_forward_event_rows(self, last_id, current_id, limit):
async def get_all_new_forward_event_rows(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
"""Returns new events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns: Deferred[List[Tuple]]
Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1020,18 +1026,20 @@ def get_all_new_forward_event_rows(txn):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
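A hypothetical replication consumer paging through the stream in bounded batches with the converted coroutine (the store object and token handling are assumptions, not Synapse's replication code):

    from typing import Any, List, Tuple

    async def drain_forward_events(store: Any, last_id: int, current_id: int) -> None:
        LIMIT = 100
        while last_id < current_id:
            rows: List[Tuple] = await store.get_all_new_forward_event_rows(
                last_id, current_id, LIMIT
            )
            if not rows:
                break
            for row in rows:
                stream_id = row[0]  # the stream id is the first element, per the docstring
                # ... cast the remaining fields into an EventsStreamRow ...
            last_id = rows[-1][0]  # advance past the last processed row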

def get_ex_outlier_stream_rows(self, last_id, current_id):
async def get_ex_outlier_stream_rows(
self, last_id: int, current_id: int
) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
Returns: Deferred[List[Tuple]]
Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1054,7 +1062,7 @@ def get_ex_outlier_stream_rows_txn(txn):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)

@@ -1226,11 +1234,11 @@ async def get_event_ordering(self, event_id):

return (int(res["topological_ordering"]), int(res["stream_ordering"]))

def get_next_event_to_expire(self):
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
Returns: Deferred[Optional[Tuple[str, int]]]
Returns:
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
@@ -1246,6 +1254,6 @@ def get_next_event_to_expire_txn(txn):

return txn.fetchone()

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
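Finally, a hedged sketch of how a caller might consume the Optional[Tuple[str, int]] that get_next_event_to_expire now returns (the scheduling logic is an assumption, not Synapse's actual expiry handler):

    import asyncio
    import time
    from typing import Any

    async def schedule_next_expiry(store: Any) -> None:
        entry = await store.get_next_event_to_expire()
        if entry is None:
            return  # the event_expiry table is empty
        event_id, expiry_ts_ms = entry
        delay_s = max(0, expiry_ts_ms - int(time.time() * 1000)) / 1000
        await asyncio.sleep(delay_s)
        # ... expire `event_id` here, then look up the next entry ...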
