This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add basic read/write lock #15782

Merged: 9 commits, Jul 5, 2023
Changes from 2 commits
226 changes: 170 additions & 56 deletions synapse/storage/databases/main/lock.py
@@ -25,6 +25,7 @@
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.util import Clock
from synapse.util.stringutils import random_string

@@ -68,12 +69,18 @@ def __init__(
self._reactor = hs.get_reactor()
self._instance_name = hs.get_instance_id()

# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
self._live_tokens: WeakValueDictionary[
# A map from `(lock_name, lock_key)` to lock that we think we
# currently hold.
self._live_lock_tokens: WeakValueDictionary[
Tuple[str, str], Lock
] = WeakValueDictionary()

# A map from `(lock_name, lock_key, token)` to read/write lock that we
# think we currently hold.
self._live_read_write_lock_tokens: WeakValueDictionary[
Tuple[str, str, str], Lock
] = WeakValueDictionary()

# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.
# However, a) it should be a small window, b) the lock is best effort
@@ -91,11 +98,13 @@ async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
logger.info("Dropping held locks due to shutdown")

# We need to take a copy of the tokens dict as dropping the locks will
# cause the dictionary to change.
locks = dict(self._live_tokens)
# We need to take a copy of the locks as dropping the locks will cause
# the dictionary to change.
locks = list(self._live_lock_tokens.values()) + list(
self._live_read_write_lock_tokens.values()
)

for lock in locks.values():
for lock in locks:
await lock.release()

logger.info("Dropped locks due to shutdown")
Expand All @@ -122,7 +131,7 @@ async def _try_acquire_lock(
"""

# Check if this process has taken out a lock and if it's still valid.
lock = self._live_tokens.get((lock_name, lock_key))
lock = self._live_lock_tokens.get((lock_name, lock_key))
if lock and await lock.is_still_valid():
return None

@@ -176,61 +185,115 @@ def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
self._reactor,
self._clock,
self,
read_write=False,
lock_name=lock_name,
lock_key=lock_key,
token=token,
)

self._live_tokens[(lock_name, lock_key)] = lock
self._live_lock_tokens[(lock_name, lock_key)] = lock

return lock

async def _is_lock_still_valid(
self, lock_name: str, lock_key: str, token: str
) -> bool:
"""Checks whether this instance still holds the lock."""
last_renewed_ts = await self.db_pool.simple_select_one_onecol(
table="worker_locks",
keyvalues={
"lock_name": lock_name,
"lock_key": lock_key,
"token": token,
},
retcol="last_renewed_ts",
allow_none=True,
desc="is_lock_still_valid",
)
return (
last_renewed_ts is not None
and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
)
async def try_acquire_read_write_lock(
self,
lock_name: str,
lock_key: str,
write: bool,
) -> Optional["Lock"]:
"""Try to acquire a lock for the given name/key. Will return an async
context manager if the lock is successfully acquired, which *must* be
used (otherwise the lock will leak).
"""

async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None:
"""Attempt to renew the lock if we still hold it."""
await self.db_pool.simple_update(
table="worker_locks",
keyvalues={
"lock_name": lock_name,
"lock_key": lock_key,
"token": token,
},
updatevalues={"last_renewed_ts": self._clock.time_msec()},
desc="renew_lock",
)
now = self._clock.time_msec()
token = random_string(6)

async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None:
"""Attempt to drop the lock, if we still hold it"""
await self.db_pool.simple_delete(
table="worker_locks",
keyvalues={
"lock_name": lock_name,
"lock_key": lock_key,
"token": token,
},
desc="drop_lock",
def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
# We attempt to acquire the lock by inserting into
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.

delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""

insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (lock_name, lock_key, token)
DO UPDATE
SET
last_renewed_ts = EXCLUDED.last_renewed_ts
"""

if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)

return

try:
await self.db_pool.runInteraction(
"try_acquire_read_write_lock",
_try_acquire_read_write_lock_txn,
)
except self.database_engine.module.IntegrityError:
return None
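
As a side note on the upsert above: re-running the INSERT with the same `(lock_name, lock_key, token)` hits the `ON CONFLICT` target and merely refreshes `last_renewed_ts`, whereas a conflicting acquisition trips some other constraint on the table (the schema is not part of this diff) and surfaces as the `IntegrityError` handled here. A toy sqlite3 demonstration of the first half of that, using a made-up `toy_locks` table rather than the real one:

# Illustrative sketch (not part of the diff); requires SQLite 3.24+ for UPSERT.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE toy_locks (
        lock_name TEXT NOT NULL,
        lock_key TEXT NOT NULL,
        token TEXT NOT NULL,
        last_renewed_ts INTEGER NOT NULL,
        UNIQUE (lock_name, lock_key, token)
    )
    """
)

upsert = """
    INSERT INTO toy_locks (lock_name, lock_key, token, last_renewed_ts)
    VALUES (?, ?, ?, ?)
    ON CONFLICT (lock_name, lock_key, token)
    DO UPDATE SET last_renewed_ts = excluded.last_renewed_ts
"""

conn.execute(upsert, ("name", "key", "abc123", 1000))
conn.execute(upsert, ("name", "key", "abc123", 2000))  # same token: renews, no error

(ts,) = conn.execute("SELECT last_renewed_ts FROM toy_locks").fetchone()
assert ts == 2000
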

lock = Lock(
self._reactor,
self._clock,
self,
read_write=True,
lock_name=lock_name,
lock_key=lock_key,
token=token,
)

self._live_tokens.pop((lock_name, lock_key), None)
self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock

return lock
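
For orientation, here is a sketch of how calling code would use this new API; `store` is assumed to be a `LockStore`, and the lock name/key values are made up rather than taken from this diff. The returned object is the async context manager the docstring describes, and `None` means a conflicting lock is already held.

# Illustrative sketch (not part of the diff).
async def do_work_under_lock(store) -> None:
    # `store` is assumed to be a LockStore; "my_lock" / "some-key" are
    # placeholder values.
    lock = await store.try_acquire_read_write_lock("my_lock", "some-key", write=True)
    if lock is None:
        # Someone else holds a conflicting lock; bail out (or retry later).
        return

    async with lock:
        ...  # the lock is renewed in the background while this block runs
    # Leaving the block releases the lock and deletes its row from the DB.
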


class Lock:
@@ -259,20 +322,31 @@ def __init__(
reactor: IReactorCore,
clock: Clock,
store: LockStore,
read_write: bool,
lock_name: str,
lock_key: str,
token: str,
) -> None:
self._reactor = reactor
self._clock = clock
self._store = store
self._read_write = read_write
self._lock_name = lock_name
self._lock_key = lock_key

self._token = token

self._table = "worker_read_write_locks" if read_write else "worker_locks"

self._looping_call = clock.looping_call(
self._renew, _RENEWAL_INTERVAL_MS, store, lock_name, lock_key, token
self._renew,
_RENEWAL_INTERVAL_MS,
store,
clock,
read_write,
lock_name,
lock_key,
token,
)

self._dropped = False
@@ -281,6 +355,8 @@ def __init__(
@wrap_as_background_process("Lock._renew")
async def _renew(
store: LockStore,
clock: Clock,
read_write: bool,
lock_name: str,
lock_key: str,
token: str,
@@ -291,12 +367,34 @@ async def _renew(
don't end up with a reference to `self` in the reactor, which would stop
this from being cleaned up if we dropped the context manager.
"""
await store._renew_lock(lock_name, lock_key, token)
table = "worker_read_write_locks" if read_write else "worker_locks"
await store.db_pool.simple_update(
table=table,
keyvalues={
"lock_name": lock_name,
"lock_key": lock_key,
"token": token,
},
updatevalues={"last_renewed_ts": clock.time_msec()},
desc="renew_lock",
)
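
The reason `_renew` takes `store`, `clock` and the lock's identifiers as arguments rather than touching `self` is the docstring's point above: a callback that captures `self` would give the reactor a strong reference to the `Lock`, so dropping the context manager could never garbage-collect it (and the weak-value maps would never shrink). A simplified, Twisted-free sketch of the difference; none of the classes below come from this diff.

# Illustrative sketch (not part of the diff): a scheduler holding a bound
# method keeps its object alive, while a callback built only from plain
# arguments (as _renew is) lets the object be collected.
import gc
import weakref

scheduled = []  # stand-in for the reactor's pending calls

class BadLock:
    def __init__(self) -> None:
        scheduled.append(self._renew)  # bound method: strong reference to self

    def _renew(self) -> None:
        pass

class GoodLock:
    def __init__(self, token: str) -> None:
        self._token = token
        # Only plain arguments are captured, never self.
        scheduled.append(lambda token=token: GoodLock._renew(token))

    @staticmethod
    def _renew(token: str) -> None:
        pass

bad_ref = weakref.ref(BadLock())
good_ref = weakref.ref(GoodLock("abc123"))
gc.collect()  # force collection on interpreters without prompt refcounting

assert bad_ref() is not None  # kept alive by the scheduled bound method
assert good_ref() is None     # collected; nothing scheduled refers to it
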

async def is_still_valid(self) -> bool:
"""Check if the lock is still held by us"""
return await self._store._is_lock_still_valid(
self._lock_name, self._lock_key, self._token
last_renewed_ts = await self._store.db_pool.simple_select_one_onecol(
table=self._table,
keyvalues={
"lock_name": self._lock_name,
"lock_key": self._lock_key,
"token": self._token,
},
retcol="last_renewed_ts",
allow_none=True,
desc="is_lock_still_valid",
)
return (
last_renewed_ts is not None
and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
)
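
The staleness rule here is simply: the lock counts as held only if it was renewed within the last `_LOCK_TIMEOUT_MS`. A tiny worked example with assumed constant values (the real constants are defined elsewhere in this file, not in this diff):

# Illustrative arithmetic (not part of the diff); both constants below are
# assumed values, not the real ones.
_LOCK_TIMEOUT_MS = 120_000     # assume: 2 minutes
_RENEWAL_INTERVAL_MS = 30_000  # assume: a healthy holder renews well inside the timeout

last_renewed_ts = 1_000_000
assert (last_renewed_ts + 90_000) - _LOCK_TIMEOUT_MS < last_renewed_ts        # 90s later: still valid
assert not ((last_renewed_ts + 130_000) - _LOCK_TIMEOUT_MS < last_renewed_ts)  # 130s with no renewal: stale
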

async def __aenter__(self) -> None:
@@ -325,7 +423,23 @@ async def release(self) -> None:
if self._looping_call.running:
self._looping_call.stop()

await self._store._drop_lock(self._lock_name, self._lock_key, self._token)
await self._store.db_pool.simple_delete(
table=self._table,
keyvalues={
"lock_name": self._lock_name,
"lock_key": self._lock_key,
"token": self._token,
},
desc="drop_lock",
)

if self._read_write:
self._store._live_read_write_lock_tokens.pop(
(self._lock_name, self._lock_key, self._token), None
)
else:
self._store._live_lock_tokens.pop((self._lock_name, self._lock_key), None)

self._dropped = True

def __del__(self) -> None: