Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add typing info to Notifier #8058

Merged
merged 4 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8058.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `Notifier`.
4 changes: 0 additions & 4 deletions synapse/handlers/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,10 @@ async def get_stream(
timeout=0,
as_client_event=True,
affect_presence=True,
only_keys=None,
room_id=None,
is_guest=False,
):
"""Fetches the events stream for a given user.

If `only_keys` is not None, events from keys will be sent down.
"""

if room_id:
Expand Down Expand Up @@ -93,7 +90,6 @@ async def get_stream(
auth_user,
pagin_config,
timeout,
only_keys=only_keys,
is_guest=is_guest,
explicit_room_id=room_id,
)
Expand Down
131 changes: 83 additions & 48 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@

import logging
from collections import namedtuple
from typing import Callable, Iterable, List, TypeVar
from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)

from prometheus_client import Counter

Expand All @@ -24,12 +34,14 @@
import synapse.server
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import StreamToken
from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
Expand Down Expand Up @@ -77,7 +89,13 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class.
"""

def __init__(self, user_id, rooms, current_token, time_now_ms):
def __init__(
self,
user_id: str,
rooms: Collection[str],
current_token: StreamToken,
time_now_ms: int,
):
self.user_id = user_id
self.rooms = set(rooms)
self.current_token = current_token
Expand All @@ -93,13 +111,13 @@ def __init__(self, user_id, rooms, current_token, time_now_ms):
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())

def notify(self, stream_key, stream_id, time_now_ms):
def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
"""Notify any listeners for this user of a new event from an
event source.
Args:
stream_key(str): The stream the event came from.
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
stream_key: The stream the event came from.
stream_id: The new id for the stream the event came from.
time_now_ms: The current time in milliseconds.
"""
self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token
Expand All @@ -112,7 +130,7 @@ def notify(self, stream_key, stream_id, time_now_ms):
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)

def remove(self, notifier):
def remove(self, notifier: "Notifier"):
""" Remove this listener from all the indexes in the Notifier
it knows about.
"""
Expand All @@ -123,10 +141,10 @@ def remove(self, notifier):

notifier.user_to_user_stream.pop(self.user_id)

def count_listeners(self):
def count_listeners(self) -> int:
return len(self.notify_deferred.observers())

def new_listener(self, token):
def new_listener(self, token: StreamToken) -> _NotificationListener:
"""Returns a deferred that is resolved when there is a new token
greater than the given token.

Expand Down Expand Up @@ -159,14 +177,16 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000

def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream]
self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]]

self.hs = hs
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
self.pending_new_room_events = (
[]
) # type: List[Tuple[int, EventBase, Collection[str]]]

# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
Expand All @@ -178,10 +198,9 @@ def __init__(self, hs: "synapse.server.HomeServer"):
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()

self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
else:
self.federation_sender = None

self.state_handler = hs.get_state_handler()

Expand All @@ -193,12 +212,12 @@ def __init__(self, hs: "synapse.server.HomeServer"):
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
def count_listeners():
all_user_streams = set()
all_user_streams = set() # type: Set[_NotifierUserStream]

for x in list(self.room_to_user_streams.values()):
all_user_streams |= x
for x in list(self.user_to_user_stream.values()):
all_user_streams.add(x)
for streams in list(self.room_to_user_streams.values()):
all_user_streams |= streams
for stream in list(self.user_to_user_stream.values()):
all_user_streams.add(stream)

return sum(stream.count_listeners() for stream in all_user_streams)

Expand All @@ -223,7 +242,11 @@ def add_replication_callback(self, cb: Callable[[], None]):
self.replication_callbacks.append(cb)

def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
self,
event: EventBase,
room_stream_id: int,
max_room_stream_id: int,
extra_users: Collection[str] = [],
):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
Expand All @@ -241,11 +264,11 @@ def on_new_room_event(

self.notify_replication()

def _notify_pending_new_room_events(self, max_room_stream_id):
def _notify_pending_new_room_events(self, max_room_stream_id: int):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
max_room_stream_id(int): The highest stream_id below which all
max_room_stream_id: The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
Expand All @@ -258,7 +281,9 @@ def _notify_pending_new_room_events(self, max_room_stream_id):
else:
self._on_new_room_event(event, room_stream_id, extra_users)

def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
def _on_new_room_event(
self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = []
):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
run_as_background_process(
Expand All @@ -275,13 +300,19 @@ def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)

async def _notify_app_services(self, room_stream_id):
async def _notify_app_services(self, room_stream_id: int):
try:
await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")

def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
def on_new_event(
self,
stream_key: str,
new_token: int,
users: Collection[str] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise.

Will wake up all listeners for the given users and rooms.
Expand All @@ -307,14 +338,19 @@ def on_new_event(self, stream_key, new_token, users=[], rooms=[]):

self.notify_replication()

def on_new_replication_data(self):
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
self.notify_replication()

async def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
):
self,
user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
room_ids=None,
from_token=StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
Expand Down Expand Up @@ -377,19 +413,16 @@ async def wait_for_events(

async def get_events_for(
self,
user,
pagination_config,
timeout,
only_keys=None,
is_guest=False,
explicit_room_id=None,
):
user: UserID,
pagination_config: PaginationConfig,
timeout: int,
is_guest: bool = False,
explicit_room_id: str = None,
) -> EventStreamResult:
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.

If `only_keys` is not None, events from keys will be sent down.

If explicit_room_id is not set, the user's joined rooms will be polled
for events.
If explicit_room_id is set, that room will be polled for events only if
Expand All @@ -404,11 +437,13 @@ async def get_events_for(
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined

async def check_for_updates(before_token, after_token):
async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token))

events = []
events = [] # type: List[EventBase]
end_token = from_token

for name, source in self.event_sources.sources.items():
Expand All @@ -417,8 +452,6 @@ async def check_for_updates(before_token, after_token):
after_id = getattr(after_token, keyname)
if before_id == after_id:
continue
if only_keys and name not in only_keys:
continue

new_events, new_key = await source.get_new_events(
user=user,
Expand Down Expand Up @@ -476,7 +509,9 @@ async def check_for_updates(before_token, after_token):

return result

async def _get_room_ids(self, user, explicit_room_id):
async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
) -> Tuple[Collection[str], bool]:
joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
Expand All @@ -486,7 +521,7 @@ async def _get_room_ids(self, user, explicit_room_id):
raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True

async def _is_world_readable(self, room_id):
async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
Expand All @@ -496,7 +531,7 @@ async def _is_world_readable(self, room_id):
return False

@log_function
def remove_expired_streams(self):
def remove_expired_streams(self) -> None:
time_now_ms = self.clock.time_msec()
expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
Expand All @@ -510,21 +545,21 @@ def remove_expired_streams(self):
expired_stream.remove(self)

@log_function
def _register_with_keys(self, user_stream):
def _register_with_keys(self, user_stream: _NotifierUserStream):
self.user_to_user_stream[user_stream.user_id] = user_stream

for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)

def _user_joined_room(self, user_id, room_id):
def _user_joined_room(self, user_id: str, room_id: str):
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id)

def notify_replication(self):
def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
Expand Down
6 changes: 6 additions & 0 deletions synapse/server.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.typing import FollowerTypingHandler
from synapse.replication.tcp.streams import Stream
from synapse.streams.events import EventSources

class HomeServer(object):
@property
Expand Down Expand Up @@ -153,3 +155,7 @@ class HomeServer(object):
pass
def get_typing_handler(self) -> FollowerTypingHandler:
pass
def get_event_sources(self) -> EventSources:
pass
def get_application_service_handler(self):
return ApplicationServicesHandler(self)
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ commands = mypy \
synapse/logging/ \
synapse/metrics \
synapse/module_api \
synapse/notifier.py \
synapse/push/pusherpool.py \
synapse/push/push_rule_evaluator.py \
synapse/replication \
Expand Down