Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to synapse/storage/databases/main/stats.py (#11653)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel committed Dec 29, 2021
1 parent fcfe675 commit 15bb1c8
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 42 deletions.
1 change: 1 addition & 0 deletions changelog.d/11653.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to storage classes.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ exclude = (?x)
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/

Expand Down Expand Up @@ -214,6 +213,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.profile]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.stats]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True

Expand Down
94 changes: 53 additions & 41 deletions synapse/storage/databases/main/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
import logging
from enum import Enum
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast

from typing_extensions import Counter

from twisted.internet.defer import DeferredLock

from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
Expand Down Expand Up @@ -122,7 +126,9 @@ def __init__(
self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
self.db_pool.updates.register_noop_background_update("populate_stats_prepare")

async def _populate_stats_process_users(self, progress, batch_size):
async def _populate_stats_process_users(
self, progress: JsonDict, batch_size: int
) -> int:
"""
This is a background update which regenerates statistics for users.
"""
Expand All @@ -134,7 +140,7 @@ async def _populate_stats_process_users(self, progress, batch_size):

last_user_id = progress.get("last_user_id", "")

def _get_next_batch(txn):
def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT name FROM users
WHERE name > ?
Expand Down Expand Up @@ -168,7 +174,9 @@ def _get_next_batch(txn):

return len(users_to_work_on)

async def _populate_stats_process_rooms(self, progress, batch_size):
async def _populate_stats_process_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
Expand All @@ -178,7 +186,7 @@ async def _populate_stats_process_rooms(self, progress, batch_size):

last_room_id = progress.get("last_room_id", "")

def _get_next_batch(txn):
def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ?
Expand Down Expand Up @@ -307,7 +315,7 @@ async def bulk_update_stats_delta(
stream_id: Current position.
"""

def _bulk_update_stats_delta_txn(txn):
def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
logger.debug(
Expand Down Expand Up @@ -339,7 +347,7 @@ async def update_stats_delta(
stats_type: str,
stats_id: str,
fields: Dict[str, int],
complete_with_stream_id: Optional[int],
complete_with_stream_id: int,
absolute_field_overrides: Optional[Dict[str, int]] = None,
) -> None:
"""
Expand Down Expand Up @@ -372,14 +380,14 @@ async def update_stats_delta(

def _update_stats_delta_txn(
self,
txn,
ts,
stats_type,
stats_id,
fields,
complete_with_stream_id,
absolute_field_overrides=None,
):
txn: LoggingTransaction,
ts: int,
stats_type: str,
stats_id: str,
fields: Dict[str, int],
complete_with_stream_id: int,
absolute_field_overrides: Optional[Dict[str, int]] = None,
) -> None:
if absolute_field_overrides is None:
absolute_field_overrides = {}

Expand Down Expand Up @@ -422,20 +430,23 @@ def _update_stats_delta_txn(
)

def _upsert_with_additive_relatives_txn(
self, txn, table, keyvalues, absolutes, additive_relatives
):
self,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
absolutes: Dict[str, Any],
additive_relatives: Dict[str, int],
) -> None:
"""Used to update values in the stats tables.
This is basically a slightly convoluted upsert that *adds* to any
existing rows.
Args:
txn
table (str): Table name
keyvalues (dict[str, any]): Row-identifying key values
absolutes (dict[str, any]): Absolute (set) fields
additive_relatives (dict[str, int]): Fields that will be added onto
if existing row present.
table: Table name
keyvalues: Row-identifying key values
absolutes: Absolute (set) fields
additive_relatives: Fields that will be added onto if existing row present.
"""
if self.database_engine.can_native_upsert:
absolute_updates = [
Expand Down Expand Up @@ -491,20 +502,17 @@ def _upsert_with_additive_relatives_txn(
current_row.update(absolutes)
self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)

async def _calculate_and_set_initial_state_for_room(
self, room_id: str
) -> Tuple[dict, dict, int]:
async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
"""Calculate and insert an entry into room_stats_current.
Args:
room_id: The room ID under calculation.
Returns:
A tuple of room state, membership counts and stream position.
"""

def _fetch_current_state_stats(txn):
pos = self.get_room_max_stream_ordering()
def _fetch_current_state_stats(
txn: LoggingTransaction,
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]

rows = self.db_pool.simple_select_many_txn(
txn,
Expand All @@ -524,7 +532,7 @@ def _fetch_current_state_stats(txn):
retcols=["event_id"],
)

event_ids = [row["event_id"] for row in rows]
event_ids = cast(List[str], [row["event_id"] for row in rows])

txn.execute(
"""
Expand All @@ -544,9 +552,9 @@ def _fetch_current_state_stats(txn):
(room_id,),
)

(current_state_events_count,) = txn.fetchone()
current_state_events_count = cast(Tuple[int], txn.fetchone())[0]

users_in_room = self.get_users_in_room_txn(txn, room_id)
users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined]

return (
event_ids,
Expand All @@ -566,7 +574,7 @@ def _fetch_current_state_stats(txn):
"get_initial_state_for_room", _fetch_current_state_stats
)

state_event_map = await self.get_events(event_ids, get_prev_content=False)
state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]

room_state = {
"join_rules": None,
Expand Down Expand Up @@ -622,8 +630,10 @@ def _fetch_current_state_stats(txn):
},
)

async def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user_txn(txn):
async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
def _calculate_and_set_initial_state_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[int, int]:
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)

txn.execute(
Expand All @@ -634,7 +644,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn):
""",
(user_id,),
)
(count,) = txn.fetchone()
count = cast(Tuple[int], txn.fetchone())[0]
return count, pos

joined_rooms, pos = await self.db_pool.runInteraction(
Expand Down Expand Up @@ -678,7 +688,9 @@ async def get_users_media_usage_paginate(
users that exist given this query
"""

def get_users_media_usage_paginate_txn(txn):
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
filters = []
args = [self.hs.config.server.server_name]

Expand Down Expand Up @@ -733,7 +745,7 @@ def get_users_media_usage_paginate_txn(txn):
sql_base=sql_base,
)
txn.execute(sql, args)
count = txn.fetchone()[0]
count = cast(Tuple[int], txn.fetchone())[0]

sql = """
SELECT
Expand Down

0 comments on commit 15bb1c8

Please sign in to comment.