Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to the push mailer. #8882

Merged
merged 1 commit into from
Dec 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8882.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to push module.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ files =
synapse/metrics,
synapse/module_api,
synapse/notifier.py,
synapse/push/mailer.py,
synapse/push/pusherpool.py,
synapse/push/push_rule_evaluator.py,
synapse/replication,
Expand Down
123 changes: 83 additions & 40 deletions synapse/push/mailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,28 @@
import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar

import bleach
import jinja2

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
name_from_member_event,
)
from synapse.types import UserID
from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

T = TypeVar("T")
Expand Down Expand Up @@ -93,7 +97,13 @@


class Mailer:
def __init__(self, hs, app_name, template_html, template_text):
def __init__(
self,
hs: "HomeServer",
app_name: str,
template_html: jinja2.Template,
template_text: jinja2.Template,
):
self.hs = hs
self.template_html = template_html
self.template_text = template_text
Expand All @@ -108,17 +118,19 @@ def __init__(self, hs, app_name, template_html, template_text):

logger.info("Created Mailer for app_name %s" % app_name)

async def send_password_reset_mail(self, email_address, token, client_secret, sid):
async def send_password_reset_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a password reset link to a user

Args:
email_address (str): Email address we're sending the password
email_address: Email address we're sending the password
reset to
token (str): Unique token generated by the server to verify
token: Unique token generated by the server to verify
the email was received
client_secret (str): Unique token generated by the client to
client_secret: Unique token generated by the client to
group together multiple email sending attempts
sid (str): The generated session ID
sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
Expand All @@ -136,17 +148,19 @@ async def send_password_reset_mail(self, email_address, token, client_secret, si
template_vars,
)

async def send_registration_mail(self, email_address, token, client_secret, sid):
async def send_registration_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a registration confirmation link to a user

Args:
email_address (str): Email address we're sending the registration
email_address: Email address we're sending the registration
link to
token (str): Unique token generated by the server to verify
token: Unique token generated by the server to verify
the email was received
client_secret (str): Unique token generated by the client to
client_secret: Unique token generated by the client to
group together multiple email sending attempts
sid (str): The generated session ID
sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
Expand All @@ -164,18 +178,20 @@ async def send_registration_mail(self, email_address, token, client_secret, sid)
template_vars,
)

async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
async def send_add_threepid_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a validation link to a user for adding a 3pid to their account

Args:
email_address (str): Email address we're sending the validation link to
email_address: Email address we're sending the validation link to

token (str): Unique token generated by the server to verify the email was received
token: Unique token generated by the server to verify the email was received

client_secret (str): Unique token generated by the client to group together
client_secret: Unique token generated by the client to group together
multiple email sending attempts

sid (str): The generated session ID
sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
Expand All @@ -194,16 +210,21 @@ async def send_add_threepid_mail(self, email_address, token, client_secret, sid)
)

async def send_notification_mail(
self, app_id, user_id, email_address, push_actions, reason
):
self,
app_id: str,
user_id: str,
email_address: str,
push_actions: Iterable[Dict[str, Any]],
reason: Dict[str, Any],
) -> None:
"""Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])

notif_events = await self.store.get_events(
[pa["event_id"] for pa in push_actions]
)

notifs_by_room = {}
notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)

Expand Down Expand Up @@ -262,7 +283,9 @@ async def _fetch_room_state(room_id):

await self.send_email(email_address, summary_text, template_vars)

async def send_email(self, email_address, subject, extra_template_vars):
async def send_email(
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
) -> None:
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
Expand Down Expand Up @@ -315,8 +338,13 @@ async def send_email(self, email_address, subject, extra_template_vars):
)

async def get_room_vars(
self, room_id, user_id, notifs, notif_events, room_state_ids
):
self,
room_id: str,
user_id: str,
notifs: Iterable[Dict[str, Any]],
notif_events: Dict[str, EventBase],
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
# Check if one of the notifs is an invite event for the user.
is_invite = False
for n in notifs:
Expand All @@ -334,7 +362,7 @@ async def get_room_vars(
"notifs": [],
"invite": is_invite,
"link": self.make_room_link(room_id),
}
} # type: Dict[str, Any]

if not is_invite:
for n in notifs:
Expand Down Expand Up @@ -365,7 +393,13 @@ async def get_room_vars(

return room_vars

async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
async def get_notif_vars(
self,
notif: Dict[str, Any],
user_id: str,
notif_event: EventBase,
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
Expand All @@ -391,7 +425,9 @@ async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):

return ret

async def get_message_vars(self, notif, event, room_state_ids):
async def get_message_vars(
self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
) -> Optional[Dict[str, Any]]:
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None

Expand Down Expand Up @@ -432,7 +468,9 @@ async def get_message_vars(self, notif, event, room_state_ids):

return ret

def add_text_message_vars(self, messagevars, event):
def add_text_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
) -> None:
msgformat = event.content.get("format")

messagevars["format"] = msgformat
Expand All @@ -445,15 +483,18 @@ def add_text_message_vars(self, messagevars, event):
elif body:
messagevars["body_text_html"] = safe_text(body)

return messagevars

def add_image_message_vars(self, messagevars, event):
def add_image_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
) -> None:
messagevars["image_url"] = event.content["url"]

return messagevars

async def make_summary_text(
self, notifs_by_room, room_state_ids, notif_events, user_id, reason
self,
notifs_by_room: Dict[str, List[Dict[str, Any]]],
room_state_ids: Dict[str, StateMap[str]],
notif_events: Dict[str, EventBase],
user_id: str,
reason: Dict[str, Any],
):
if len(notifs_by_room) == 1:
# Only one room has new stuff
Expand Down Expand Up @@ -580,7 +621,7 @@ async def make_summary_text(
"app": self.app_name,
}

def make_room_link(self, room_id):
def make_room_link(self, room_id: str) -> str:
if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector":
Expand All @@ -590,7 +631,7 @@ def make_room_link(self, room_id):
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)

def make_notif_link(self, notif):
def make_notif_link(self, notif: Dict[str, str]) -> str:
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
Expand All @@ -606,7 +647,9 @@ def make_notif_link(self, notif):
else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])

def make_unsubscribe_link(self, user_id, app_id, email_address):
def make_unsubscribe_link(
self, user_id: str, app_id: str, email_address: str
) -> str:
params = {
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id,
Expand All @@ -620,7 +663,7 @@ def make_unsubscribe_link(self, user_id, app_id, email_address):
)


def safe_markup(raw_html):
def safe_markup(raw_html: str) -> jinja2.Markup:
return jinja2.Markup(
bleach.linkify(
bleach.clean(
Expand All @@ -635,7 +678,7 @@ def safe_markup(raw_html):
)


def safe_text(raw_text):
def safe_text(raw_text: str) -> jinja2.Markup:
"""
Process text: treat it as HTML but escape any tags (ie. just escape the
HTML) then linkify it.
Expand All @@ -655,7 +698,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
return ret


def string_ordinal_total(s):
def string_ordinal_total(s: str) -> int:
tot = 0
for c in s:
tot += ord(c)
Expand Down