Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add cache invalidation across workers to module API
Browse files Browse the repository at this point in the history
Signed-off-by: Mathieu Velten <[email protected]>
  • Loading branch information
Mathieu Velten committed Aug 30, 2022
1 parent 4249082 commit 9143cbe
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 19 deletions.
4 changes: 2 additions & 2 deletions scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
"synapse.util.caches.descriptors.CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
):
Expand All @@ -38,7 +38,7 @@ def get_method_signature_hook(


def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct.
"""Fixes the `CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except:
Expand Down
16 changes: 15 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import CachedFunction, cached
from synapse.util.frozenutils import freeze

if TYPE_CHECKING:
Expand Down Expand Up @@ -836,6 +836,20 @@ def run_db_interaction(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)

async def invalidate_cache(
    self, cached_func: "CachedFunction", keys: Tuple[Any, ...]
) -> None:
    """Invalidate a cached function on this worker and on all other workers.

    Args:
        cached_func: The `@cached`-decorated function to invalidate. It must
            have been registered on every worker via `register_cached_function`
            for the remote invalidation to be picked up.
        keys: The cache entry key (one element per cached argument).
    """
    # Invalidate locally first so this worker immediately sees fresh data.
    cached_func.invalidate(keys)
    # Stream the invalidation to other workers; they look the function up
    # by its qualified name, which is why registration uses __qualname__.
    await self._store.send_invalidation_to_replication(
        cached_func.__qualname__,
        keys,
    )

def register_cached_function(self, cached_func: "CachedFunction") -> None:
    """Register a module-defined cached function with the datastore.

    This makes the function reachable by qualified name when a cache
    invalidation for it arrives over the replication stream, so
    cross-worker invalidation (see `invalidate_cache`) can find it.

    Args:
        cached_func: The `@cached`-decorated function to register.
    """
    self._store.register_external_cached_function(
        cached_func.__qualname__, cached_func
    )

def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
) -> None:
Expand Down
6 changes: 4 additions & 2 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _invalidate_state_caches(

def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
) -> None:
) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
Expand All @@ -115,7 +115,7 @@ def _attempt_to_invalidate_cache(
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
return
return False

if key is None:
cache.invalidate_all()
Expand All @@ -125,6 +125,8 @@ def _attempt_to_invalidate_cache(
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))

return True


def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Expand Down
29 changes: 22 additions & 7 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -91,6 +91,11 @@ def __init__(
else:
self._cache_id_gen = None

self.external_cached_functions = {}

def register_external_cached_function(
    self, cache_name: str, func: "CachedFunction"
) -> None:
    """Record a cached function registered by a module so that replicated
    invalidations for `cache_name` can be forwarded to it.

    Args:
        cache_name: Qualified name the invalidation rows will carry.
        func: The `@cached`-decorated function to invalidate.
    """
    self.external_cached_functions[cache_name] = func

async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
Expand Down Expand Up @@ -178,7 +183,11 @@ def process_replication_rows(
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
res = self._attempt_to_invalidate_cache(row.cache_func, row.keys)
if not res:
external_func = self.external_cached_functions[row.cache_func]
if external_func:
external_func.invalidate(row.keys)

super().process_replication_rows(stream_name, instance_name, token, rows)

Expand Down Expand Up @@ -269,17 +278,15 @@ async def invalidate_cache_and_stream(
return

cache_func.invalidate(keys)
await self.db_pool.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
await self.send_invalidation_to_replication(
cache_func.__name__,
keys,
)

def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
Expand All @@ -293,7 +300,7 @@ def _invalidate_cache_and_stream(
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
Expand Down Expand Up @@ -334,6 +341,14 @@ def _invalidate_state_caches_and_stream(
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)

async def send_invalidation_to_replication(
    self, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
    """Notify all other workers to invalidate a cache entry.

    Args:
        cache_name: Name of the cached function to invalidate.
        keys: Entry key to invalidate, or None to invalidate the whole cache.
    """
    await self.db_pool.runInteraction(
        "send_invalidation_to_replication",
        self._send_invalidation_to_replication,
        cache_name,
        keys,
    )

def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
Expand Down
14 changes: 7 additions & 7 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
F = TypeVar("F", bound=Callable[..., Any])


class _CachedFunction(Generic[F]):
class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
Expand Down Expand Up @@ -242,7 +242,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:

return ret2

wrapped = cast(_CachedFunction, _wrapped)
wrapped = cast(CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.name] = wrapped

Expand Down Expand Up @@ -363,7 +363,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:

return make_deferred_yieldable(ret)

wrapped = cast(_CachedFunction, _wrapped)
wrapped = cast(CachedFunction, _wrapped)

if self.num_args == 1:
assert not self.tree
Expand Down Expand Up @@ -572,7 +572,7 @@ def cached(
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
Expand All @@ -585,7 +585,7 @@ def cached(
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)
return cast(Callable[[F], CachedFunction[F]], func)


def cachedList(
Expand All @@ -594,7 +594,7 @@ def cachedList(
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
Expand Down Expand Up @@ -631,7 +631,7 @@ def batch_do_something(self, first_arg, second_args):
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)
return cast(Callable[[F], CachedFunction[F]], func)


def _get_cache_key_builder(
Expand Down
112 changes: 112 additions & 0 deletions tests/replication/test_module_cache_invalidation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import synapse
from synapse.module_api import cached

from tests.replication._base import BaseMultiWorkerStreamTestCase

logger = logging.getLogger(__name__)

FIRST_VALUE = "one"
SECOND_VALUE = "two"

KEY = "mykey"


class TestCache:
    """A minimal object exposing a `@cached`-decorated async function.

    Used to observe cache behaviour: tests mutate `current_value` and then
    check whether `cached_function` still serves the stale cached result.
    """

    # Value returned on a cache miss; tests change this to SECOND_VALUE to
    # detect whether a subsequent call hits the cache or recomputes.
    current_value = FIRST_VALUE

    @cached()
    async def cached_function(self, user_id: str) -> str:
        # Deliberately returns mutable state so staleness is observable.
        return self.current_value


class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
    """Tests for module-API cache invalidation across workers.

    Checks that a plain `invalidate()` only affects the local process, while
    `ModuleApi.invalidate_cache` is replicated to registered caches on other
    workers.
    """

    servlets = [
        synapse.rest.admin.register_servlets,
    ]

    def test_module_cache_local_invalidation_only(self) -> None:
        """A direct `invalidate()` call must not propagate to other workers."""
        main_cache = TestCache()

        self.make_worker_hs("synapse.app.generic_worker")

        worker_cache = TestCache()

        # Prime both caches with FIRST_VALUE.
        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
        self.assertEqual(
            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

        main_cache.current_value = SECOND_VALUE
        worker_cache.current_value = SECOND_VALUE
        # No local invalidation yet, should return the cached value on both the main process and the worker
        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
        self.assertEqual(
            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

        # Local invalidation on the main process only; the worker should still
        # return the stale cached value.
        main_cache.cached_function.invalidate((KEY,))
        self.assertEqual(
            SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
        )
        self.assertEqual(
            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

    def test_module_cache_full_invalidation(self) -> None:
        """`ModuleApi.invalidate_cache` should invalidate on every worker."""
        # NOTE(review): registering a user looks unnecessary, but it appears to
        # initialise some replication state — the test does not work without
        # it. TODO: confirm what exactly this sets up.
        self.register_user("user", "pass")

        main_cache = TestCache()
        self.hs.get_module_api().register_cached_function(main_cache.cached_function)

        worker_hs = self.make_worker_hs("synapse.app.generic_worker")

        worker_cache = TestCache()
        worker_hs.get_module_api().register_cached_function(
            worker_cache.cached_function
        )

        # Prime both caches with FIRST_VALUE.
        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
        self.assertEqual(
            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

        main_cache.current_value = SECOND_VALUE
        worker_cache.current_value = SECOND_VALUE
        # No invalidation yet, should return the cached value on both the main process and the worker
        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
        self.assertEqual(
            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

        # Full invalidation on the main process; it is replicated to the
        # worker, which should then return the updated value too.
        self.get_success(
            self.hs.get_module_api().invalidate_cache(
                main_cache.cached_function, (KEY,)
            )
        )

        self.assertEqual(
            SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
        )
        self.assertEqual(
            SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

0 comments on commit 9143cbe

Please sign in to comment.