This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse/storage/databases/main/transactions.py #11589

Merged: 3 commits, merged on Dec 16, 2021
changelog.d/11589.misc (1 change: 1 addition & 0 deletions)
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
mypy.ini (4 changes: 3 additions & 1 deletion)
@@ -41,7 +41,6 @@ exclude = (?x)
 |synapse/storage/databases/main/search.py
 |synapse/storage/databases/main/state.py
 |synapse/storage/databases/main/stats.py
-|synapse/storage/databases/main/transactions.py
 |synapse/storage/databases/main/user_directory.py
 |synapse/storage/schema/

@@ -216,6 +215,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.state_deltas]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.transactions]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.user_erasure_store]
 disallow_untyped_defs = True

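With the module dropped from the mypy exclude list and covered by its own disallow_untyped_defs = True section, mypy now requires every function in it to be fully annotated. As a rough illustration (a standalone toy module, not code from this PR), this is the class of error the new setting catches:

# toy_module.py -- hypothetical example, checked under "disallow_untyped_defs = True"

def scale(value, factor):  # mypy: error: Function is missing a type annotation
    return value * factor

def scale_typed(value: int, factor: int) -> int:  # accepted: parameters and return type annotated
    return value * factor

Running mypy with the project's configuration now checks synapse/storage/databases/main/transactions.py under these stricter per-module rules.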
synapse/storage/databases/main/transactions.py (49 changes: 25 additions & 24 deletions)
@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -39,16 +38,6 @@
 logger = logging.getLogger(__name__)
 
 
-_TransactionRow = namedtuple(
-    "_TransactionRow",
-    ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
-)
-
-_UpdateTransactionRow = namedtuple(
-    "_TransactionRow", ("response_code", "response_json")
-)
-
-
 class DestinationSortOrder(Enum):
     """Enum to define the sorting method used when returning destinations."""
 
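The two module-level namedtuples are deleted outright rather than annotated; nothing in the shown diff still uses them. For comparison only, and not part of this change: if a typed container for such a row were still wanted, an attrs class along these lines (the module already imports attr) would give mypy per-field types, which a plain collections.namedtuple does not. Field types here are assumptions made for illustration.

# Hypothetical typed equivalent of the removed _TransactionRow; illustrative only.
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class TransactionRow:
    id: int
    transaction_id: str
    destination: str
    ts: int
    response_code: int
    response_json: bytes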
@@ -91,7 +80,7 @@ async def _cleanup_transactions(self) -> None:
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
 
-        def _cleanup_transactions_txn(txn):
+        def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
         await self.db_pool.runInteraction(
@@ -121,7 +110,9 @@ async def get_received_txn_response(
             origin,
         )
 
-    def _get_received_txn_response(self, txn, transaction_id, origin):
+    def _get_received_txn_response(
+        self, txn: LoggingTransaction, transaction_id: str, origin: str
+    ) -> Optional[Tuple[int, JsonDict]]:
         result = self.db_pool.simple_select_one_txn(
             txn,
             table="received_transactions",
@@ -196,7 +187,7 @@ async def get_destination_retry_timings(
         return result
 
     def _get_destination_retry_timings(
-        self, txn, destination: str
+        self, txn: LoggingTransaction, destination: str
     ) -> Optional[DestinationRetryTimings]:
         result = self.db_pool.simple_select_one_txn(
             txn,
@@ -231,7 +222,7 @@ async def set_destination_retry_timings(
         """
 
         if self.database_engine.can_native_upsert:
-            return await self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "set_destination_retry_timings",
                 self._set_destination_retry_timings_native,
                 destination,
@@ -241,7 +232,7 @@ async def set_destination_retry_timings(
                 db_autocommit=True,  # Safe as its a single upsert
             )
         else:
-            return await self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "set_destination_retry_timings",
                 self._set_destination_retry_timings_emulated,
                 destination,
@@ -251,8 +242,13 @@ async def set_destination_retry_timings(
             )
 
     def _set_destination_retry_timings_native(
-        self, txn, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+        self,
+        txn: LoggingTransaction,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         assert self.database_engine.can_native_upsert
 
         # Upsert retry time interval if retry_interval is zero (i.e. we're
@@ -282,8 +278,13 @@ def _set_destination_retry_timings_native(
         )
 
     def _set_destination_retry_timings_emulated(
-        self, txn, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+        self,
+        txn: LoggingTransaction,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         self.database_engine.lock_table(txn, "destinations")
 
         # We need to be careful here as the data may have changed from under us
@@ -393,7 +394,7 @@ async def set_destination_last_successful_stream_ordering(
             last_successful_stream_ordering: the stream_ordering of the most
                 recent successfully-sent PDU
         """
-        return await self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             "destinations",
             keyvalues={"destination": destination},
             values={"last_successful_stream_ordering": last_successful_stream_ordering},
@@ -534,7 +535,7 @@ def get_destinations_paginate_txn(
             else:
                 order = "ASC"
 
-            args = []
+            args: List[object] = []
             where_statement = ""
             if destination:
                 args.extend(["%" + destination.lower() + "%"])
@@ -543,7 +544,7 @@ def get_destinations_paginate_txn(
             sql_base = f"FROM destinations {where_statement} "
             sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = f"""
                 SELECT destination, retry_last_ts, retry_interval, failure_ts,
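The cast on the last changed line is the one non-mechanical fix in this file: the transaction object's fetchone() is typed as possibly returning None, so indexing the result directly fails type checking even though a COUNT(*) query always yields exactly one row. A small self-contained sketch of the pattern (FakeCursor is a hypothetical stand-in, not Synapse code):

from typing import Optional, Tuple, cast

class FakeCursor:
    """Stand-in for a cursor whose fetchone() is annotated as Optional."""

    def fetchone(self) -> Optional[Tuple[int]]:
        return (42,)

txn = FakeCursor()

# count = txn.fetchone()[0]                  # rejected by mypy: the value may be None
count = cast(Tuple[int], txn.fetchone())[0]  # assert the row exists, as a COUNT(*) always returns one
print(count)  # prints 42

At runtime cast() simply returns its second argument, so the behaviour is unchanged; it only tells the type checker that the None case cannot occur here.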