From ce6819a7015847084208ff09758c5a13c1c4c429 Mon Sep 17 00:00:00 2001 From: John-Scott Atlakson <24574+jsma@users.noreply.github.com> Date: Thu, 19 Aug 2021 03:16:00 -0700 Subject: [PATCH 01/28] Fix typo in release notes (#10646) Ubuntu 20.10 was not an LTS release Signed-off-by: John-Scott Atlakson 24574+jsma@users.noreply.github.com --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 01766af39cc6..cad9423ebd1a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -48,7 +48,7 @@ Improved Documentation Deprecations and Removals ------------------------- -- No longer build `.deb` packages for Ubuntu 20.10 LTS Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) +- No longer build `.deb` packages for Ubuntu 20.10 Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) - The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10596](https://github.com/matrix-org/synapse/issues/10596)) From 5cda75fedef3dd02d3b456231be0a1b4bff2a31a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 20 Aug 2021 07:17:50 -0400 Subject: [PATCH 02/28] Set room version 8 as preferred for restricted rooms. (#10571) --- changelog.d/10571.feature | 1 + synapse/api/room_versions.py | 2 +- synapse/config/experimental.py | 2 +- tests/rest/client/v2_alpha/test_capabilities.py | 4 ++-- 4 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10571.feature diff --git a/changelog.d/10571.feature b/changelog.d/10571.feature new file mode 100644 index 000000000000..0da318cd5b34 --- /dev/null +++ b/changelog.d/10571.feature @@ -0,0 +1 @@ +Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 11280c446220..8abcdfd4fd9a 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -293,7 +293,7 @@ class RoomVersionCapability: ), RoomVersionCapability( "restricted", - None, + RoomVersions.V8, lambda room_version: room_version.msc3083_join_rules, ), ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b918fb15b04e..907df9591a85 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -37,7 +37,7 @@ def read_config(self, config: JsonDict, **kwargs): self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) # MSC3244 (room version capabilities) - self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) + self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index ad83b3d2ff58..13b3c5f499b7 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -102,7 +102,8 @@ def test_get_change_password_capabilities_password_disabled(self): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_does_not_include_msc3244_fields_by_default(self): + @override_config({"experimental_features": {"msc3244_enabled": False}}) + def test_get_does_not_include_msc3244_fields_when_disabled(self): localpart = "user" password = "pass" user = self.register_user(localpart, password) @@ -120,7 +121,6 @@ def test_get_does_not_include_msc3244_fields_by_default(self): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - @override_config({"experimental_features": {"msc3244_enabled": True}}) def test_get_does_include_msc3244_fields_when_enabled(self): localpart = "user" password = "pass" From 31dac7ffeeb02f68d1dbe068fd241239e02208dc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Aug 2021 08:00:25 -0400 Subject: [PATCH 03/28] Do not include stack traces for known exceptions when trying multiple federation destinations. (#10662) --- changelog.d/10662.misc | 1 + synapse/federation/federation_client.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10662.misc diff --git a/changelog.d/10662.misc b/changelog.d/10662.misc new file mode 100644 index 000000000000..593f9ceaad5a --- /dev/null +++ b/changelog.d/10662.misc @@ -0,0 +1 @@ +Do not print out stack traces for network errors when fetching data over federation. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 29979414e3d7..44d9e8a5c734 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -43,6 +43,7 @@ Codes, FederationDeniedError, HttpResponseException, + RequestSendFailed, SynapseError, UnsupportedRoomVersionError, ) @@ -558,7 +559,11 @@ async def _try_destination_list( try: return await callback(destination) - except InvalidResponseError as e: + except ( + RequestSendFailed, + InvalidResponseError, + NotRetryingDestination, + ) as e: logger.warning("Failed to %s via %s: %s", description, destination, e) except UnsupportedRoomVersionError: raise From 2af6d31b78109a989e27128ac655990c35b29d62 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Aug 2021 08:14:17 -0400 Subject: [PATCH 04/28] Addtional type hints for the REST servlets. (#10665) --- changelog.d/10665.misc | 1 + synapse/rest/client/account_validity.py | 39 ++++------ synapse/rest/client/capabilities.py | 3 +- synapse/rest/client/directory.py | 78 ++++++++++++------- synapse/rest/client/pusher.py | 22 ++++-- synapse/rest/client/read_marker.py | 15 +++- .../rest/client/room_upgrade_rest_servlet.py | 18 +++-- synapse/rest/client/shared_rooms.py | 16 +++- synapse/rest/client/tags.py | 25 ++++-- synapse/rest/client/thirdparty.py | 36 ++++++--- synapse/rest/client/tokenrefresh.py | 14 +++- synapse/rest/client/user_directory.py | 17 ++-- synapse/rest/client/versions.py | 14 +++- synapse/rest/client/voip.py | 13 +++- 14 files changed, 204 insertions(+), 107 deletions(-) create mode 100644 changelog.d/10665.misc diff --git a/changelog.d/10665.misc b/changelog.d/10665.misc new file mode 100644 index 000000000000..39a37b90b1b3 --- /dev/null +++ b/changelog.d/10665.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py index 3ebe40186153..6c24b96c547d 100644 --- a/synapse/rest/client/account_validity.py +++ b/synapse/rest/client/account_validity.py @@ -13,24 +13,27 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import SynapseError -from synapse.http.server import respond_with_html -from synapse.http.servlet import RestServlet +from twisted.web.server import Request + +from synapse.http.server import HttpServer, respond_with_html +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -46,18 +49,14 @@ def __init__(self, hs): hs.config.account_validity.account_validity_invalid_token_template ) - async def on_GET(self, request): - if b"token" not in request.args: - raise SynapseError(400, "Missing renewal token") - renewal_token = request.args[b"token"][0] + async def on_GET(self, request: Request) -> None: + renewal_token = parse_string(request, "token", required=True) ( token_valid, token_stale, expiration_ts, - ) = await self.account_activity_handler.renew_account( - renewal_token.decode("utf8") - ) + ) = await self.account_activity_handler.renew_account(renewal_token) if token_valid: status_code = 200 @@ -77,11 +76,7 @@ async def on_GET(self, request): class AccountValiditySendMailServlet(RestServlet): PATTERNS = client_patterns("/account_validity/send_mail$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -91,7 +86,7 @@ def __init__(self, hs): hs.config.account_validity.account_validity_renew_by_email_enabled ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() await self.account_activity_handler.send_renewal_email_to_user(user_id) @@ -99,6 +94,6 @@ async def on_POST(self, request): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountValidityRenewServlet(hs).register(http_server) AccountValiditySendMailServlet(hs).register(http_server) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 093549512ebc..65b3b5ce2cc5 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.types import JsonDict @@ -75,5 +76,5 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return 200, response -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: CapabilitiesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index ffa075c8e5f6..ee247e3d1e0d 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.server import Request from synapse.api.errors import ( AuthError, @@ -22,14 +24,19 @@ NotFoundError, SynapseError, ) +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import RoomAlias +from synapse.types import JsonDict, RoomAlias + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ClientDirectoryServer(hs).register(http_server) ClientDirectoryListServer(hs).register(http_server) ClientAppserviceDirectoryListServer(hs).register(http_server) @@ -38,21 +45,23 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(RestServlet): PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) + async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]: + room_alias_obj = RoomAlias.from_string(room_alias) - res = await self.directory_handler.get_association(room_alias) + res = await self.directory_handler.get_association(room_alias_obj) return 200, res - async def on_PUT(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) + async def on_PUT( + self, request: SynapseRequest, room_alias: str + ) -> Tuple[int, JsonDict]: + room_alias_obj = RoomAlias.from_string(room_alias) content = parse_json_object_from_request(request) if "room_id" not in content: @@ -61,7 +70,7 @@ async def on_PUT(self, request, room_alias): ) logger.debug("Got content: %s", content) - logger.debug("Got room name: %s", room_alias.to_string()) + logger.debug("Got room name: %s", room_alias_obj.to_string()) room_id = content["room_id"] servers = content["servers"] if "servers" in content else None @@ -78,22 +87,25 @@ async def on_PUT(self, request, room_alias): requester = await self.auth.get_user_by_req(request) await self.directory_handler.create_association( - requester, room_alias, room_id, servers + requester, room_alias_obj, room_id, servers ) return 200, {} - async def on_DELETE(self, request, room_alias): + async def on_DELETE( + self, request: SynapseRequest, room_alias: str + ) -> Tuple[int, JsonDict]: + room_alias_obj = RoomAlias.from_string(room_alias) + try: service = self.auth.get_appservice_by_req(request) - room_alias = RoomAlias.from_string(room_alias) await self.directory_handler.delete_appservice_association( - service, room_alias + service, room_alias_obj ) logger.info( "Application service at %s deleted alias %s", service.url, - room_alias.to_string(), + room_alias_obj.to_string(), ) return 200, {} except InvalidClientCredentialsError: @@ -103,12 +115,10 @@ async def on_DELETE(self, request, room_alias): requester = await self.auth.get_user_by_req(request) user = requester.user - room_alias = RoomAlias.from_string(room_alias) - - await self.directory_handler.delete_association(requester, room_alias) + await self.directory_handler.delete_association(requester, room_alias_obj) logger.info( - "User %s deleted alias %s", user.to_string(), room_alias.to_string() + "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string() ) return 200, {} @@ -117,20 +127,22 @@ async def on_DELETE(self, request, room_alias): class ClientDirectoryListServer(RestServlet): PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: room = await self.store.get_room(room_id) if room is None: raise NotFoundError("Unknown room") return 200, {"visibility": "public" if room["is_public"] else "private"} - async def on_PUT(self, request, room_id): + async def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -142,7 +154,9 @@ async def on_PUT(self, request, room_id): return 200, {} - async def on_DELETE(self, request, room_id): + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await self.directory_handler.edit_published_room_list( @@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet): "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() - def on_PUT(self, request, network_id, room_id): + async def on_PUT( + self, request: SynapseRequest, network_id: str, room_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - return self._edit(request, network_id, room_id, visibility) + return await self._edit(request, network_id, room_id, visibility) - def on_DELETE(self, request, network_id, room_id): - return self._edit(request, network_id, room_id, "private") + async def on_DELETE( + self, request: SynapseRequest, network_id: str, room_id: str + ) -> Tuple[int, JsonDict]: + return await self._edit(request, network_id, room_id, "private") - async def _edit(self, request, network_id, room_id, visibility): + async def _edit( + self, request: SynapseRequest, network_id: str, room_id: str, visibility: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not requester.app_service: raise AuthError( diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 84619c5e4184..98604a93887f 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -13,17 +13,23 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.http.server import respond_with_html_bytes +from synapse.http.server import HttpServer, respond_with_html_bytes from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.push import PusherConfigException from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -31,12 +37,12 @@ class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = requester.user @@ -50,14 +56,14 @@ async def on_GET(self, request): class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = requester.user @@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"You have been unsubscribed" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.notifier = hs.get_notifier() self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user @@ -165,7 +171,7 @@ async def on_GET(self, request): return None -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushersRestServlet(hs).register(http_server) PushersSetRestServlet(hs).register(http_server) PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 027f8b81fa93..43c04fac6fdb 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -13,27 +13,36 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReadMarkerRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await self.presence_handler.bump_presence_active_time(requester.user) @@ -70,5 +79,5 @@ async def on_POST(self, request, room_id): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReadMarkerRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py index 6d1b083acb47..6a7792e18b2e 100644 --- a/synapse/rest/client/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -13,18 +13,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from synapse.util import stringutils from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet): } Creates a new room and shuts down the old one. Returns the ID of the new room. - - Args: - hs (synapse.server.HomeServer): """ PATTERNS = client_patterns( @@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet): "/rooms/(?P[^/]*)/upgrade$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -84,5 +90,5 @@ async def on_POST(self, request, room_id): return 200, ret -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomUpgradeRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index d2e7f04b406c..1d90493eb082 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet -from synapse.types import UserID +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet): releases=(), # This is an unstable feature ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.user_directory_active = hs.config.update_user_directory - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: if not self.user_directory_active: raise SynapseError( @@ -63,5 +71,5 @@ async def on_GET(self, request, user_id): return 200, {"joined": list(rooms)} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index c14f83be1878..c88cb9367c5f 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -13,12 +13,19 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -29,12 +36,14 @@ class TagListServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, user_id, room_id): + async def on_GET( + self, request: SynapseRequest, user_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") @@ -54,12 +63,14 @@ class TagServlet(RestServlet): "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.handler = hs.get_account_data_handler() - async def on_PUT(self, request, user_id, room_id, tag): + async def on_PUT( + self, request: SynapseRequest, user_id: str, room_id: str, tag: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") @@ -70,7 +81,9 @@ async def on_PUT(self, request, user_id, room_id, tag): return 200, {} - async def on_DELETE(self, request, user_id, room_id, tag): + async def on_DELETE( + self, request: SynapseRequest, user_id: str, room_id: str, tag: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") @@ -80,6 +93,6 @@ async def on_DELETE(self, request, user_id, room_id, tag): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: TagListServlet(hs).register(http_server) TagServlet(hs).register(http_server) diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py index b5c67c9bb67e..b895c73acf2c 100644 --- a/synapse/rest/client/thirdparty.py +++ b/synapse/rest/client/thirdparty.py @@ -12,27 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING, Dict, List, Tuple from synapse.api.constants import ThirdPartyEntityKind +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocols") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) protocols = await self.appservice_handler.get_3pe_protocols() @@ -42,13 +48,15 @@ async def on_GET(self, request): class ThirdPartyProtocolServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request, protocol): + async def on_GET( + self, request: SynapseRequest, protocol: str + ) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) protocols = await self.appservice_handler.get_3pe_protocols( @@ -63,16 +71,18 @@ async def on_GET(self, request, protocol): class ThirdPartyUserServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request, protocol): + async def on_GET( + self, request: SynapseRequest, protocol: str + ) -> Tuple[int, List[JsonDict]]: await self.auth.get_user_by_req(request, allow_guest=True) - fields = request.args + fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment] fields.pop(b"access_token", None) results = await self.appservice_handler.query_3pe( @@ -85,16 +95,18 @@ async def on_GET(self, request, protocol): class ThirdPartyLocationServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request, protocol): + async def on_GET( + self, request: SynapseRequest, protocol: str + ) -> Tuple[int, List[JsonDict]]: await self.auth.get_user_by_req(request, allow_guest=True) - fields = request.args + fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment] fields.pop(b"access_token", None) results = await self.appservice_handler.query_3pe( @@ -104,7 +116,7 @@ async def on_GET(self, request, protocol): return 200, results -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ThirdPartyProtocolsServlet(hs).register(http_server) ThirdPartyProtocolServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server) diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py index b2f858545cbe..c8c3b25bd36f 100644 --- a/synapse/rest/client/tokenrefresh.py +++ b/synapse/rest/client/tokenrefresh.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + +from twisted.web.server import Request + from synapse.api.errors import AuthError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + class TokenRefreshRestServlet(RestServlet): """ @@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet): PATTERNS = client_patterns("/tokenrefresh") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> None: raise AuthError(403, "tokenrefresh is no longer supported.") -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: TokenRefreshRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index 7e8912f0b919..885281111438 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -13,29 +13,32 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class UserDirectorySearchRestServlet(RestServlet): PATTERNS = client_patterns("/user_directory/search$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Searches for users in directory Returns: @@ -75,5 +78,5 @@ async def on_POST(self, request): return 200, results -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserDirectorySearchRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index fa2e4e9cba48..a1a815cf8256 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -17,9 +17,17 @@ import logging import re +from typing import TYPE_CHECKING, Tuple + +from twisted.web.server import Request from synapse.api.constants import RoomCreationPreset +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -27,7 +35,7 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config @@ -45,7 +53,7 @@ def __init__(self, hs): in self.config.encryption_enabled_by_default_for_room_presets ) - def on_GET(self, request): + def on_GET(self, request: Request) -> Tuple[int, JsonDict]: return ( 200, { @@ -89,5 +97,5 @@ def on_GET(self, request): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: VersionsRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py index f53020520d37..9d46ed3af3f5 100644 --- a/synapse/rest/client/voip.py +++ b/synapse/rest/client/voip.py @@ -15,20 +15,27 @@ import base64 import hashlib import hmac +from typing import TYPE_CHECKING, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer class VoipRestServlet(RestServlet): PATTERNS = client_patterns("/voip/turnServer$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req( request, self.hs.config.turn_allow_guests ) @@ -69,5 +76,5 @@ async def on_GET(self, request): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: VoipRestServlet(hs).register(http_server) From bd7d398b05aaa18d5b0629153ababeea7539256c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Aug 2021 08:14:42 -0400 Subject: [PATCH 05/28] Additional type hints for the sync REST servlet. (#10666) --- changelog.d/10666.misc | 1 + synapse/handlers/sync.py | 21 +++--- synapse/rest/client/sync.py | 132 ++++++++++++++++++++++-------------- 3 files changed, 93 insertions(+), 61 deletions(-) create mode 100644 changelog.d/10666.misc diff --git a/changelog.d/10666.misc b/changelog.d/10666.misc new file mode 100644 index 000000000000..39a37b90b1b3 --- /dev/null +++ b/changelog.d/10666.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2203c45dcc9a..86c3c7f0df50 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -30,6 +30,7 @@ from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection +from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging.context import current_context @@ -231,7 +232,7 @@ class SyncResult: """ next_batch: StreamToken - presence: List[JsonDict] + presence: List[UserPresenceState] account_data: List[JsonDict] joined: List[JoinedSyncResult] invited: List[InvitedSyncResult] @@ -2177,14 +2178,14 @@ class SyncResultBuilder: joined_room_ids: List of rooms the user is joined to # The following mirror the fields in a sync response - presence (list) - account_data (list) - joined (list[JoinedSyncResult]) - invited (list[InvitedSyncResult]) - knocked (list[KnockedSyncResult]) - archived (list[ArchivedSyncResult]) - groups (GroupsSyncResult|None) - to_device (list) + presence + account_data + joined + invited + knocked + archived + groups + to_device """ sync_config: SyncConfig @@ -2193,7 +2194,7 @@ class SyncResultBuilder: now_token: StreamToken joined_room_ids: FrozenSet[str] - presence: List[JsonDict] = attr.Factory(list) + presence: List[UserPresenceState] = attr.Factory(list) account_data: List[JsonDict] = attr.Factory(list) joined: List[JoinedSyncResult] = attr.Factory(list) invited: List[InvitedSyncResult] = attr.Factory(list) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index e18f4d01b375..65c37be3e96c 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -14,17 +14,26 @@ import itertools import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.api.presence import UserPresenceState from synapse.events.utils import ( format_event_for_client_v2_without_room_id, format_event_raw, ) from synapse.handlers.presence import format_user_presence_state -from synapse.handlers.sync import KnockedSyncResult, SyncConfig +from synapse.handlers.sync import ( + ArchivedSyncResult, + InvitedSyncResult, + JoinedSyncResult, + KnockedSyncResult, + SyncConfig, + SyncResult, +) +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.types import JsonDict, StreamToken @@ -192,6 +201,8 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return 200, {} time_now = self.clock.time_msec() + # We know that the the requester has an access token since appservices + # cannot use sync. response_content = await self.encode_response( time_now, sync_result, requester.access_token_id, filter_collection ) @@ -199,7 +210,13 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: logger.debug("Event formatting complete") return 200, response_content - async def encode_response(self, time_now, sync_result, access_token_id, filter): + async def encode_response( + self, + time_now: int, + sync_result: SyncResult, + access_token_id: Optional[int], + filter: FilterCollection, + ) -> JsonDict: logger.debug("Formatting events in sync response") if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id @@ -234,7 +251,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter): logger.debug("building sync response dict") - response: dict = defaultdict(dict) + response: JsonDict = defaultdict(dict) response["next_batch"] = await sync_result.next_batch.to_string(self.store) if sync_result.account_data: @@ -274,6 +291,8 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter): if archived: response["rooms"][Membership.LEAVE] = archived + # By the time we get here groups is no longer optional. + assert sync_result.groups is not None if sync_result.groups.join: response["groups"][Membership.JOIN] = sync_result.groups.join if sync_result.groups.invite: @@ -284,7 +303,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter): return response @staticmethod - def encode_presence(events, time_now): + def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict: return { "events": [ { @@ -299,25 +318,27 @@ def encode_presence(events, time_now): } async def encode_joined( - self, rooms, time_now, token_id, event_fields, event_formatter - ): + self, + rooms: List[JoinedSyncResult], + time_now: int, + token_id: Optional[int], + event_fields: List[str], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Encode the joined rooms in a sync result Args: - rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync - results for rooms this user is joined to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing + rooms: list of sync results for rooms this user is joined to + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_fields(list): List of event fields to include. If empty, + event_fields: List of event fields to include. If empty, all fields will be returned. - event_formatter (func[dict]): function to convert from federation format + event_formatter: function to convert from federation format to client format Returns: - dict[str, dict[str, object]]: the joined rooms list, in our - response format + The joined rooms list, in our response format """ joined = {} for room in rooms: @@ -332,23 +353,26 @@ async def encode_joined( return joined - async def encode_invited(self, rooms, time_now, token_id, event_formatter): + async def encode_invited( + self, + rooms: List[InvitedSyncResult], + time_now: int, + token_id: Optional[int], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Encode the invited rooms in a sync result Args: - rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of - sync results for rooms this user is invited to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing + rooms: list of sync results for rooms this user is invited to + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_formatter (func[dict]): function to convert from federation format + event_formatter: function to convert from federation format to client format Returns: - dict[str, dict[str, object]]: the invited rooms list, in our - response format + The invited rooms list, in our response format """ invited = {} for room in rooms: @@ -371,7 +395,7 @@ async def encode_knocked( self, rooms: List[KnockedSyncResult], time_now: int, - token_id: int, + token_id: Optional[int], event_formatter: Callable[[Dict], Dict], ) -> Dict[str, Dict[str, Any]]: """ @@ -422,25 +446,26 @@ async def encode_knocked( return knocked async def encode_archived( - self, rooms, time_now, token_id, event_fields, event_formatter - ): + self, + rooms: List[ArchivedSyncResult], + time_now: int, + token_id: Optional[int], + event_fields: List[str], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Encode the archived rooms in a sync result Args: - rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of - sync results for rooms this user is joined to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing + rooms: list of sync results for rooms this user is joined to + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_fields(list): List of event fields to include. If empty, + event_fields: List of event fields to include. If empty, all fields will be returned. - event_formatter (func[dict]): function to convert from federation format - to client format + event_formatter: function to convert from federation format to client format Returns: - dict[str, dict[str, object]]: The invited rooms list, in our - response format + The archived rooms list, in our response format """ joined = {} for room in rooms: @@ -456,23 +481,27 @@ async def encode_archived( return joined async def encode_room( - self, room, time_now, token_id, joined, only_fields, event_formatter - ): + self, + room: Union[JoinedSyncResult, ArchivedSyncResult], + time_now: int, + token_id: Optional[int], + joined: bool, + only_fields: Optional[List[str]], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Args: - room (JoinedSyncResult|ArchivedSyncResult): sync result for a - single room - time_now (int): current time - used as a baseline for age - calculations - token_id (int): ID of the user's auth token - used for namespacing + room: sync result for a single room + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - joined (bool): True if the user is joined to this room - will mean + joined: True if the user is joined to this room - will mean we handle ephemeral events - only_fields(list): Optional. The list of event fields to include. - event_formatter (func[dict]): function to convert from federation format + only_fields: Optional. The list of event fields to include. + event_formatter: function to convert from federation format to client format Returns: - dict[str, object]: the room, encoded in our response format + The room, encoded in our response format """ def serialize(events): @@ -508,7 +537,7 @@ def serialize(events): account_data = room.account_data - result = { + result: JsonDict = { "timeline": { "events": serialized_timeline, "prev_batch": await room.timeline.prev_batch.to_string(self.store), @@ -519,6 +548,7 @@ def serialize(events): } if joined: + assert isinstance(room, JoinedSyncResult) ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications @@ -528,5 +558,5 @@ def serialize(events): return result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) From 2efc838f050f0608f6648c0235eaade813d66f08 Mon Sep 17 00:00:00 2001 From: Dan Callahan Date: Mon, 23 Aug 2021 14:06:49 +0100 Subject: [PATCH 06/28] Avoid duplicate issues from Twisted trunk failures (#10672) Setting `update_existing: true` in the `create-an-issue` GitHub Action will avoid opening duplicate issues if an open issue already exists with an identical title. If no open issues match the title, then a new issue will be created. This helps avoid spamming our issue tracker should there be a failure when testing against Twisted's trunk. This PR also pins the SHA of the `create-an-issue` action to mitigate the risk of a malicious actor gaining access to JasonEtco's account. See GitHub's page on security hardening third party actions for more: https://docs.github.com/en/actions/learn-github-actions/security-hardening-for-github-actions#using-third-party-actions Signed-off-by: Dan Callahan --- .github/workflows/twisted_trunk.yml | 3 ++- changelog.d/10672.misc | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10672.misc diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 0bf77905458e..b5c729888f57 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -82,8 +82,9 @@ jobs: steps: - uses: actions/checkout@v2 - - uses: JasonEtco/create-an-issue@v2 + - uses: JasonEtco/create-an-issue@5d9504915f79f9cc6d791934b8ef34f2353dd74d # v2.5.0, 2020-12-06 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: + update_existing: true filename: .ci/twisted_trunk_build_failed_issue_template.md diff --git a/changelog.d/10672.misc b/changelog.d/10672.misc new file mode 100644 index 000000000000..7104c121e02a --- /dev/null +++ b/changelog.d/10672.misc @@ -0,0 +1 @@ +Run a nightly CI build against Twisted trunk. From 3e83f97154e12fc50ccf2d8b4a01007cff012c47 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 23 Aug 2021 14:58:31 +0100 Subject: [PATCH 07/28] Fix the titles in the OIDC documentation (#10639) * Fix the titles in the OIDC documentation Having them as links broke the table-of-contents rendering in mdbook. Plus there's no reason for only some of the provider titles to be links. * Changelog * Add link to google idp docs --- changelog.d/10639.doc | 1 + docs/openid.md | 16 ++++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10639.doc diff --git a/changelog.d/10639.doc b/changelog.d/10639.doc new file mode 100644 index 000000000000..acbac4aad8ec --- /dev/null +++ b/changelog.d/10639.doc @@ -0,0 +1 @@ +Fix some of the titles not rendering in the OIDC documentation. diff --git a/docs/openid.md b/docs/openid.md index f685fd551acc..f121bc8a6e3b 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -79,7 +79,7 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` -### [Dex][dex-idp] +### Dex [Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. Although it is designed to help building a full-blown provider with an @@ -117,7 +117,7 @@ oidc_providers: localpart_template: "{{ user.name }}" display_name_template: "{{ user.name|capitalize }}" ``` -### [Keycloak][keycloak-idp] +### Keycloak [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. @@ -166,7 +166,9 @@ oidc_providers: localpart_template: "{{ user.preferred_username }}" display_name_template: "{{ user.name }}" ``` -### [Auth0][auth0] +### Auth0 + +[Auth0][auth0] is a hosted SaaS IdP solution. 1. Create a regular web application for Synapse 2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback` @@ -209,7 +211,7 @@ oidc_providers: ### GitHub -GitHub is a bit special as it is not an OpenID Connect compliant provider, but +[GitHub][github-idp] is a bit special as it is not an OpenID Connect compliant provider, but just a regular OAuth2 provider. The [`/user` API endpoint](https://developer.github.com/v3/users/#get-the-authenticated-user) @@ -242,11 +244,13 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` -### [Google][google-idp] +### Google + +[Google][google-idp] is an OpenID certified authentication and authorisation provider. 1. Set up a project in the Google API Console (see https://developers.google.com/identity/protocols/oauth2/openid-connect#appsetup). -2. add an "OAuth Client ID" for a Web Application under "Credentials". +2. Add an "OAuth Client ID" for a Web Application under "Credentials". 3. Copy the Client ID and Client Secret, and add the following to your synapse config: ```yaml oidc_providers: From 0c1d6f65d7c65efd8491adf4efc2620148e2841a Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Mon, 23 Aug 2021 16:25:33 +0100 Subject: [PATCH 08/28] Enforce the max length for per-room display names / avatar URLs. (#10654) To match the maximum lengths allowed for profile data. --- changelog.d/10654.bugfix | 1 + synapse/handlers/room_member.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10654.bugfix diff --git a/changelog.d/10654.bugfix b/changelog.d/10654.bugfix new file mode 100644 index 000000000000..b0bd78453fab --- /dev/null +++ b/changelog.d/10654.bugfix @@ -0,0 +1 @@ +Enforce the maximum length for per-room display names and avatar URLs. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ba131962185f..401b84aad1eb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -36,6 +36,7 @@ from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.types import ( JsonDict, Requester, @@ -79,7 +80,7 @@ def __init__(self, hs: "HomeServer"): self.account_data_handler = hs.get_account_data_handler() self.event_auth_handler = hs.get_event_auth_handler() - self.member_linearizer = Linearizer(name="member") + self.member_linearizer: Linearizer = Linearizer(name="member") self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() @@ -556,6 +557,20 @@ async def update_membership_locked( content.pop("displayname", None) content.pop("avatar_url", None) + if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN: + raise SynapseError( + 400, + f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})", + errcode=Codes.BAD_JSON, + ) + + if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN: + raise SynapseError( + 400, + f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})", + errcode=Codes.BAD_JSON, + ) + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" From 86415f162d5dca9f38054a76e00130d947f2e8e2 Mon Sep 17 00:00:00 2001 From: Hugo DELVAL Date: Mon, 23 Aug 2021 19:12:36 +0200 Subject: [PATCH 09/28] doc: add django-oauth-toolkit to oidc doc (#10192) Signed-off-by: Hugo Delval --- changelog.d/10192.doc | 1 + docs/openid.md | 48 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 changelog.d/10192.doc diff --git a/changelog.d/10192.doc b/changelog.d/10192.doc new file mode 100644 index 000000000000..3dd00537e8d9 --- /dev/null +++ b/changelog.d/10192.doc @@ -0,0 +1 @@ +Add documentation on how to connect Django with synapse using oidc and django-oauth-toolkit. Contributed by @HugoDelval. diff --git a/docs/openid.md b/docs/openid.md index f121bc8a6e3b..49180eec5293 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -450,3 +450,51 @@ The synapse config will look like this: config: email_template: "{{ user.email }}" ``` + +## Django OAuth Toolkit + +[django-oauth-toolkit](https://github.com/jazzband/django-oauth-toolkit) is a +Django application providing out of the box all the endpoints, data and logic +needed to add OAuth2 capabilities to your Django projects. It supports +[OpenID Connect too](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html). + +Configuration on Django's side: + +1. Add an application: https://example.com/admin/oauth2_provider/application/add/ and choose parameters like this: +* `Redirect uris`: https://synapse.example.com/_synapse/client/oidc/callback +* `Client type`: `Confidential` +* `Authorization grant type`: `Authorization code` +* `Algorithm`: `HMAC with SHA-2 256` +2. You can [customize the claims](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html#customizing-the-oidc-responses) Django gives to synapse (optional): +
+ Code sample + + ```python + class CustomOAuth2Validator(OAuth2Validator): + + def get_additional_claims(self, request): + return { + "sub": request.user.email, + "email": request.user.email, + "first_name": request.user.first_name, + "last_name": request.user.last_name, + } + ``` +
+Your synapse config is then: + +```yaml +oidc_providers: + - idp_id: django_example + idp_name: "Django Example" + issuer: "https://example.com/o/" + client_id: "your-client-id" # CHANGE ME + client_secret: "your-client-secret" # CHANGE ME + scopes: ["openid"] + user_profile_method: "userinfo_endpoint" # needed because oauth-toolkit does not include user information in the authorization response + user_mapping_provider: + config: + localpart_template: "{{ user.email.split('@')[0] }}" + display_name_template: "{{ user.first_name }} {{ user.last_name }}" + email_template: "{{ user.email }}" +``` From 15db8b7c7f13f33ca49104183e0642892c3b83f1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Aug 2021 10:17:51 +0100 Subject: [PATCH 10/28] Correctly initialise the `synapse_user_logins` metric. (#10677) Fix a bug where the prometheus metrics for SSO logins wouldn't be initialised until the first user logged in with a given auth provider. --- changelog.d/10677.bugfix | 1 + synapse/handlers/register.py | 18 ++++++++++++++++++ synapse/handlers/sso.py | 2 ++ synapse/rest/client/login.py | 29 +++++++++++++++++++++++------ 4 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10677.bugfix diff --git a/changelog.d/10677.bugfix b/changelog.d/10677.bugfix new file mode 100644 index 000000000000..9964afaaeea1 --- /dev/null +++ b/changelog.d/10677.bugfix @@ -0,0 +1 @@ +Fix a bug which caused the `synapse_user_logins_total` Prometheus metric not to be correctly initialised on restart. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8cf614136ebb..0ed59d757bf7 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -56,6 +56,22 @@ ) +def init_counters_for_auth_provider(auth_provider_id: str) -> None: + """Ensure the prometheus counters for the given auth provider are initialised + + This fixes a problem where the counters are not reported for a given auth provider + until the user first logs in/registers. + """ + for is_guest in (True, False): + login_counter.labels(guest=is_guest, auth_provider=auth_provider_id) + for shadow_banned in (True, False): + registration_counter.labels( + guest=is_guest, + shadow_banned=shadow_banned, + auth_provider=auth_provider_id, + ) + + class LoginDict(TypedDict): device_id: str access_token: str @@ -96,6 +112,8 @@ def __init__(self, hs: "HomeServer"): self.session_lifetime = hs.config.session_lifetime self.access_token_lifetime = hs.config.access_token_lifetime + init_counters_for_auth_provider("") + async def check_username( self, localpart: str, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 1b855a685c4a..0e6ebb574ecf 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -213,6 +214,7 @@ def register_identity_provider(self, p: SsoIdentityProvider): p_id = p.idp_id assert p_id not in self._identity_providers self._identity_providers[p_id] = p + init_counters_for_auth_provider(p_id) def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]: """Get the configured identity providers""" diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 0c8d8967b7ee..11d07776b2ff 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,6 +104,12 @@ def __init__(self, hs: "HomeServer"): burst_count=self.hs.config.rc_login_account.burst_count, ) + # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. + # The reason for this is to ensure that the auth_provider_ids are registered + # with SsoHandler, which in turn ensures that the login/registration prometheus + # counters are initialised for the auth_provider_ids. + _load_sso_handlers(hs) + def on_GET(self, request: SynapseRequest): flows = [] if self.jwt_enabled: @@ -499,12 +505,7 @@ class SsoRedirectServlet(RestServlet): def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they # register themselves with the main SSOHandler. - if hs.config.cas_enabled: - hs.get_cas_handler() - if hs.config.saml2_enabled: - hs.get_saml_handler() - if hs.config.oidc_enabled: - hs.get_oidc_handler() + _load_sso_handlers(hs) self._sso_handler = hs.get_sso_handler() self._msc2858_enabled = hs.config.experimental.msc2858_enabled self._public_baseurl = hs.config.public_baseurl @@ -598,3 +599,19 @@ def register_servlets(hs, http_server): SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: CasTicketServlet(hs).register(http_server) + + +def _load_sso_handlers(hs: "HomeServer"): + """Ensure that the SSO handlers are loaded, if they are enabled by configuration. + + This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves + with the main SsoHandler. + + It's safe to call this multiple times. + """ + if hs.config.cas.cas_enabled: + hs.get_cas_handler() + if hs.config.saml2.saml2_enabled: + hs.get_saml_handler() + if hs.config.oidc.oidc_enabled: + hs.get_oidc_handler() From d12ba52f178982ecb47207471bee14472f9597b6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 24 Aug 2021 08:14:03 -0400 Subject: [PATCH 11/28] Persist room hierarchy pagination sessions to the database. (#10613) --- changelog.d/10613.feature | 1 + mypy.ini | 1 + synapse/app/generic_worker.py | 2 + synapse/handlers/room_summary.py | 76 ++++----- synapse/storage/databases/main/__init__.py | 2 + synapse/storage/databases/main/session.py | 145 ++++++++++++++++++ .../schema/main/delta/62/02session_store.sql | 23 +++ 7 files changed, 212 insertions(+), 38 deletions(-) create mode 100644 changelog.d/10613.feature create mode 100644 synapse/storage/databases/main/session.py create mode 100644 synapse/storage/schema/main/delta/62/02session_store.sql diff --git a/changelog.d/10613.feature b/changelog.d/10613.feature new file mode 100644 index 000000000000..ffc4e4289cfa --- /dev/null +++ b/changelog.d/10613.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/mypy.ini b/mypy.ini index b17872211e72..745e6b78eb62 100644 --- a/mypy.ini +++ b/mypy.ini @@ -57,6 +57,7 @@ files = synapse/storage/databases/main/keys.py, synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/registration.py, + synapse/storage/databases/main/session.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fd2626dbe1db..9b71dd75e6d3 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -118,6 +118,7 @@ from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.search import SearchStore +from synapse.storage.databases.main.session import SessionStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore @@ -253,6 +254,7 @@ class GenericWorkerSlavedStore( SearchStore, TransactionWorkerStore, LockStore, + SessionStore, BaseSlavedStore, ): pass diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index ac6cfc0da915..906985c754d9 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -28,12 +28,11 @@ Membership, RoomTypes, ) -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache -from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -76,6 +75,9 @@ class _PaginationSession: class RoomSummaryHandler: + # A unique key used for pagination sessions for the room hierarchy endpoint. + _PAGINATION_SESSION_TYPE = "room_hierarchy_pagination" + # The time a pagination session remains valid for. _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 @@ -87,12 +89,6 @@ def __init__(self, hs: "HomeServer"): self._server_name = hs.hostname self._federation_client = hs.get_federation_client() - # A map of query information to the current pagination state. - # - # TODO Allow for multiple workers to share this data. - # TODO Expire pagination tokens. - self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} - # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. self._pagination_response_cache: ResponseCache[ @@ -102,21 +98,6 @@ def __init__(self, hs: "HomeServer"): "get_room_hierarchy", ) - def _expire_pagination_sessions(self): - """Expire pagination session which are old.""" - expire_before = ( - self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS - ) - to_expire = [] - - for key, value in self._pagination_sessions.items(): - if value.creation_time_ms < expire_before: - to_expire.append(key) - - for key in to_expire: - logger.debug("Expiring pagination session id %s", key) - del self._pagination_sessions[key] - async def get_space_summary( self, requester: str, @@ -327,18 +308,29 @@ async def _get_room_hierarchy( # If this is continuing a previous session, pull the persisted data. if from_token: - self._expire_pagination_sessions() + try: + pagination_session = await self._store.get_session( + session_type=self._PAGINATION_SESSION_TYPE, + session_id=from_token, + ) + except StoreError: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, from_token - ) - if pagination_key not in self._pagination_sessions: + # If the requester, room ID, suggested-only, or max depth were modified + # the session is invalid. + if ( + requester != pagination_session["requester"] + or requested_room_id != pagination_session["room_id"] + or suggested_only != pagination_session["suggested_only"] + or max_depth != pagination_session["max_depth"] + ): raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) # Load the previous state. - pagination_session = self._pagination_sessions[pagination_key] - room_queue = pagination_session.room_queue - processed_rooms = pagination_session.processed_rooms + room_queue = [ + _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"] + ] + processed_rooms = set(pagination_session["processed_rooms"]) else: # The queue of rooms to process, the next room is last on the stack. room_queue = [_RoomQueueEntry(requested_room_id, ())] @@ -456,13 +448,21 @@ async def _get_room_hierarchy( # If there's additional data, generate a pagination token (and persist state). if room_queue: - next_batch = random_string(24) - result["next_batch"] = next_batch - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, next_batch - ) - self._pagination_sessions[pagination_key] = _PaginationSession( - self._clock.time_msec(), room_queue, processed_rooms + result["next_batch"] = await self._store.create_session( + session_type=self._PAGINATION_SESSION_TYPE, + value={ + # Information which must be identical across pagination. + "requester": requester, + "room_id": requested_room_id, + "suggested_only": suggested_only, + "max_depth": max_depth, + # The stored state. + "room_queue": [ + attr.astuple(room_entry) for room_entry in room_queue + ], + "processed_rooms": list(processed_rooms), + }, + expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS, ) return result diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 01b918e12e10..00a644e8f71c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -63,6 +63,7 @@ from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore +from .session import SessionStore from .signatures import SignatureStore from .state import StateStore from .stats import StatsStore @@ -121,6 +122,7 @@ class DataStore( ServerMetricsStore, EventForwardExtremitiesStore, LockStore, + SessionStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py new file mode 100644 index 000000000000..172f27d109ad --- /dev/null +++ b/synapse/storage/databases/main/session.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +import synapse.util.stringutils as stringutils +from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import JsonDict +from synapse.util import json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class SessionStore(SQLBaseStore): + """ + A store for generic session data. + + Each type of session should provide a unique type (to separate sessions). + + Sessions are automatically removed when they expire. + """ + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + # Create a background job for culling expired sessions. + if hs.config.run_background_tasks: + self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) + + async def create_session( + self, session_type: str, value: JsonDict, expiry_ms: int + ) -> str: + """ + Creates a new pagination session for the room hierarchy endpoint. + + Args: + session_type: The type for this session. + value: The value to store. + expiry_ms: How long before an item is evicted from the cache + in milliseconds. Default is 0, indicating items never get + evicted based on time. + + Returns: + The newly created session ID. + + Raises: + StoreError if a unique session ID cannot be generated. + """ + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db_pool.simple_insert( + table="sessions", + values={ + "session_id": session_id, + "session_type": session_type, + "value": json_encoder.encode(value), + "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms, + }, + desc="create_session", + ) + + return session_id + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_session(self, session_type: str, session_id: str) -> JsonDict: + """ + Retrieve data stored with create_session + + Args: + session_type: The type for this session. + session_id: The session ID returned from create_session. + + Raises: + StoreError if the session cannot be found. + """ + + def _get_session( + txn: LoggingTransaction, session_type: str, session_id: str, ts: int + ) -> JsonDict: + # This includes the expiry time since items are only periodically + # deleted, not upon expiry. + select_sql = """ + SELECT value FROM sessions WHERE + session_type = ? AND session_id = ? AND expiry_time_ms > ? + """ + txn.execute(select_sql, [session_type, session_id, ts]) + row = txn.fetchone() + + if not row: + raise StoreError(404, "No session") + + return db_to_json(row[0]) + + return await self.db_pool.runInteraction( + "get_session", + _get_session, + session_type, + session_id, + self._clock.time_msec(), + ) + + @wrap_as_background_process("delete_expired_sessions") + async def _delete_expired_sessions(self) -> None: + """Remove sessions with expiry dates that have passed.""" + + def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?" + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "delete_expired_sessions", + _delete_expired_sessions_txn, + self._clock.time_msec(), + ) diff --git a/synapse/storage/schema/main/delta/62/02session_store.sql b/synapse/storage/schema/main/delta/62/02session_store.sql new file mode 100644 index 000000000000..535fb34c109c --- /dev/null +++ b/synapse/storage/schema/main/delta/62/02session_store.sql @@ -0,0 +1,23 @@ +/* + * Copyright 2021 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS sessions( + session_type TEXT NOT NULL, -- The unique key for this type of session. + session_id TEXT NOT NULL, -- The session ID passed to the client. + value TEXT NOT NULL, -- A JSON dictionary to persist. + expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds). + UNIQUE (session_type, session_id) +); From 6f77a3d433c683223024075e805f87bec3327036 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Aug 2021 15:31:55 +0100 Subject: [PATCH 12/28] 1.41.0 --- CHANGES.md | 9 +++++++++ changelog.d/10571.feature | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 4 files changed, 16 insertions(+), 2 deletions(-) delete mode 100644 changelog.d/10571.feature diff --git a/CHANGES.md b/CHANGES.md index cad9423ebd1a..35456cded6d6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +Synapse 1.41.0 (2021-08-24) +=========================== + +Features +-------- + +- Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. ([\#10571](https://github.com/matrix-org/synapse/issues/10571)) + + Synapse 1.41.0rc1 (2021-08-18) ============================== diff --git a/changelog.d/10571.feature b/changelog.d/10571.feature deleted file mode 100644 index 0da318cd5b34..000000000000 --- a/changelog.d/10571.feature +++ /dev/null @@ -1 +0,0 @@ -Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. diff --git a/debian/changelog b/debian/changelog index 68f309b0b25b..4da4bc018cf0 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.41.0) stable; urgency=medium + + * New synapse release 1.41.0. + + -- Synapse Packaging team Tue, 24 Aug 2021 15:31:45 +0100 + matrix-synapse-py3 (1.41.0~rc1) stable; urgency=medium * New synapse release 1.41.0~rc1. diff --git a/synapse/__init__.py b/synapse/__init__.py index 6ada20a77f49..ef3770262e8f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ except ImportError: pass -__version__ = "1.41.0rc1" +__version__ = "1.41.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when From f03cafb50c49a1569f1f99485f9cc42abfdc7b21 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Aug 2021 16:06:33 +0100 Subject: [PATCH 13/28] Update changelog --- CHANGES.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 35456cded6d6..f8da8771aa6e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,10 +1,15 @@ Synapse 1.41.0 (2021-08-24) =========================== +This release adds support for Debian 12 (Bookworm), but **removes support for Ubuntu 20.10 (Groovy Gorilla)**, which reached End of Life last month. + +Note that when using workers the `/_synapse/admin/v1/users/{userId}/media` must now be handled by media workers. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. + + Features -------- -- Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. ([\#10571](https://github.com/matrix-org/synapse/issues/10571)) +- Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version when creating restricted rooms. ([\#10571](https://github.com/matrix-org/synapse/issues/10571)) Synapse 1.41.0rc1 (2021-08-18) @@ -16,7 +21,7 @@ Features - Add `get_userinfo_by_id` method to ModuleApi. ([\#9581](https://github.com/matrix-org/synapse/issues/9581)) - Initial local support for [MSC3266](https://github.com/matrix-org/synapse/pull/10394), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. ([\#10394](https://github.com/matrix-org/synapse/issues/10394)) - Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. ([\#10435](https://github.com/matrix-org/synapse/issues/10435)) -- Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. ([\#10475](https://github.com/matrix-org/synapse/issues/10475)) +- Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10596](https://github.com/matrix-org/synapse/issues/10596)). ([\#10475](https://github.com/matrix-org/synapse/issues/10475)) - Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10498](https://github.com/matrix-org/synapse/issues/10498)) - Add a configuration setting for the time a `/sync` response is cached for. ([\#10513](https://github.com/matrix-org/synapse/issues/10513)) - The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself. ([\#10518](https://github.com/matrix-org/synapse/issues/10518)) @@ -38,7 +43,7 @@ Bugfixes - Add some clarification to the sample config file. Contributed by @Kentokamoto. ([\#10129](https://github.com/matrix-org/synapse/issues/10129)) - Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. ([\#10532](https://github.com/matrix-org/synapse/issues/10532)) - Fix exceptions in logs when failing to get remote room list. ([\#10541](https://github.com/matrix-org/synapse/issues/10541)) -- Fix longstanding bug which caused the user "status" to be reset when the user went offline. Contributed by @dklimpel. ([\#10550](https://github.com/matrix-org/synapse/issues/10550)) +- Fix longstanding bug which caused the user's presence "status message" to be reset when the user went offline. Contributed by @dklimpel. ([\#10550](https://github.com/matrix-org/synapse/issues/10550)) - Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10580](https://github.com/matrix-org/synapse/issues/10580)) - Fix a bug introduced in v1.37.1 where an error could occur in the asynchronous processing of PDUs when the queue was empty. ([\#10592](https://github.com/matrix-org/synapse/issues/10592)) - Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. ([\#10606](https://github.com/matrix-org/synapse/issues/10606)) @@ -49,7 +54,7 @@ Bugfixes Improved Documentation ---------------------- -- Add documentation for configuration a forward proxy. ([\#10443](https://github.com/matrix-org/synapse/issues/10443)) +- Add documentation for configuring a forward proxy. ([\#10443](https://github.com/matrix-org/synapse/issues/10443)) - Updated the reverse proxy documentation to highlight the homserver configuration that is needed to make Synapse aware that is is intentionally reverse proxied. ([\#10551](https://github.com/matrix-org/synapse/issues/10551)) - Update CONTRIBUTING.md to fix index links and the instructions for SyTest in docker. ([\#10599](https://github.com/matrix-org/synapse/issues/10599)) From 7367473f965fed1160cb8633de341c5833e5b662 Mon Sep 17 00:00:00 2001 From: Sean Date: Wed, 25 Aug 2021 10:51:08 +0100 Subject: [PATCH 14/28] Fix error when selecting between thumbnails with the same quality (#10684) Fixes #10318 --- changelog.d/10684.bugfix | 1 + synapse/rest/media/v1/thumbnail_resource.py | 26 +++++++++----- tests/rest/media/v1/test_media_storage.py | 39 ++++++++++++++++++++- 3 files changed, 56 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10684.bugfix diff --git a/changelog.d/10684.bugfix b/changelog.d/10684.bugfix new file mode 100644 index 000000000000..311b17601a3e --- /dev/null +++ b/changelog.d/10684.bugfix @@ -0,0 +1 @@ +Fix long-standing issue which caused an error when a thumbnail is requested and there are multiple thumbnails with the same quality rating. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index a029d426f0b6..12bd745cb21c 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -15,7 +15,7 @@ import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from twisted.web.server import Request @@ -414,9 +414,9 @@ def _select_thumbnail( if desired_method == "crop": # Thumbnails that match equal or larger sizes of desired width/height. - crop_info_list = [] + crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] # Other thumbnails. - crop_info_list2 = [] + crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. if info["thumbnail_method"] != "crop": @@ -451,15 +451,19 @@ def _select_thumbnail( info, ) ) + # Pick the most appropriate thumbnail. Some values of `desired_width` and + # `desired_height` may result in a tie, in which case we avoid comparing on + # the thumbnail info dictionary and pick the thumbnail that appears earlier + # in the list of candidates. if crop_info_list: - thumbnail_info = min(crop_info_list)[-1] + thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1] elif crop_info_list2: - thumbnail_info = min(crop_info_list2)[-1] + thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1] elif desired_method == "scale": # Thumbnails that match equal or larger sizes of desired width/height. - info_list = [] + info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = [] # Other thumbnails. - info_list2 = [] + info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. @@ -477,10 +481,14 @@ def _select_thumbnail( info_list2.append( (size_quality, type_quality, length_quality, info) ) + # Pick the most appropriate thumbnail. Some values of `desired_width` and + # `desired_height` may result in a tie, in which case we avoid comparing on + # the thumbnail info dictionary and pick the thumbnail that appears earlier + # in the list of candidates. if info_list: - thumbnail_info = min(info_list)[-1] + thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1] elif info_list2: - thumbnail_info = min(info_list2)[-1] + thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1] if thumbnail_info: return FileInfo( diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6085444b9da8..2f7eebfe6931 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -21,7 +21,7 @@ from urllib import parse import attr -from parameterized import parameterized_class +from parameterized import parameterized, parameterized_class from PIL import Image as Image from twisted.internet import defer @@ -473,6 +473,43 @@ def _test_thumbnail(self, method, expected_body, expected_found): }, ) + @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) + def test_same_quality(self, method, desired_size): + """Test that choosing between thumbnails with the same quality rating succeeds. + + We are not particular about which thumbnail is chosen.""" + self.assertIsNotNone( + self.thumbnail_resource._select_thumbnail( + desired_width=desired_size, + desired_height=desired_size, + desired_method=method, + desired_type=self.test_image.content_type, + # Provide two identical thumbnails which are guaranteed to have the same + # quality rating. + thumbnail_infos=[ + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail1{self.test_image.extension}", + }, + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail2{self.test_image.extension}", + }, + ], + file_id=f"image{self.test_image.extension}", + url_cache=None, + server_name=None, + ) + ) + def test_x_robots_tag_header(self): """ Tests that the `X-Robots-Tag` header is present, which informs web crawlers From 882539e423d3eaad703cdee80582f12a27b34d58 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 25 Aug 2021 10:18:23 -0400 Subject: [PATCH 15/28] Ensure the base Docker image is rebuilt when running complement with workers. (#10686) We now always rebuild the matrixdotorg/synapse image, then build the matrixdotorg/synapse-workers image on top of it. --- changelog.d/10686.misc | 1 + scripts-dev/complement.sh | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 changelog.d/10686.misc diff --git a/changelog.d/10686.misc b/changelog.d/10686.misc new file mode 100644 index 000000000000..b76908d74ee7 --- /dev/null +++ b/changelog.d/10686.misc @@ -0,0 +1 @@ +Update `complement.sh` to rebuild the base Docker image when run with workers. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 5d0ef8dd3a73..89af7a4fde89 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -35,25 +35,25 @@ if [[ -z "$COMPLEMENT_DIR" ]]; then echo "Checkout available at 'complement-master'" fi +# Build the base Synapse image from the local checkout +docker build -t matrixdotorg/synapse -f "docker/Dockerfile" . + # If we're using workers, modify the docker files slightly. if [[ -n "$WORKERS" ]]; then - BASE_IMAGE=matrixdotorg/synapse-workers - BASE_DOCKERFILE=docker/Dockerfile-workers + # Build the workers docker image (from the base Synapse image). + docker build -t matrixdotorg/synapse-workers -f "docker/Dockerfile-workers" . + export COMPLEMENT_BASE_IMAGE=complement-synapse-workers COMPLEMENT_DOCKERFILE=SynapseWorkers.Dockerfile # And provide some more configuration to complement. export COMPLEMENT_CA=true export COMPLEMENT_VERSION_CHECK_ITERATIONS=500 else - BASE_IMAGE=matrixdotorg/synapse - BASE_DOCKERFILE=docker/Dockerfile export COMPLEMENT_BASE_IMAGE=complement-synapse COMPLEMENT_DOCKERFILE=Synapse.Dockerfile fi -# Build the base Synapse image from the local checkout -docker build -t $BASE_IMAGE -f "$BASE_DOCKERFILE" . -# Build the Synapse monolith image from Complement, based on the above image we just built +# Build the Complement image from the Synapse image we just built. docker build -t $COMPLEMENT_BASE_IMAGE -f "$COMPLEMENT_DIR/dockerfiles/$COMPLEMENT_DOCKERFILE" "$COMPLEMENT_DIR/dockerfiles" cd "$COMPLEMENT_DIR" From b45cc1530b1438b8bfd9c09f179c7338e85ac083 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 25 Aug 2021 17:00:44 +0100 Subject: [PATCH 16/28] Make a note to leave a summary when one is bumping the schema version (#10621) I found this easy to miss (and evidently, it looks like it was missed for schema version 62). --- changelog.d/10621.misc | 1 + synapse/storage/schema/__init__.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/10621.misc diff --git a/changelog.d/10621.misc b/changelog.d/10621.misc new file mode 100644 index 000000000000..b8de2e1911af --- /dev/null +++ b/changelog.d/10621.misc @@ -0,0 +1 @@ +Add a comment asking developers to leave a reason when bumping the database schema version. \ No newline at end of file diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index a5bc0ee8a560..af9cc69949c3 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# When updating these values, please leave a short summary of the changes below. + SCHEMA_VERSION = 63 """Represents the expectations made by the codebase about the database schema From 5548fe097881b543cba37c7cda27ff7efe55025d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Aug 2021 07:16:53 -0400 Subject: [PATCH 17/28] Cache the result of fetching the room hierarchy over federation. (#10647) --- changelog.d/10647.misc | 1 + synapse/federation/federation_client.py | 106 +++++++++++++++--------- 2 files changed, 67 insertions(+), 40 deletions(-) create mode 100644 changelog.d/10647.misc diff --git a/changelog.d/10647.misc b/changelog.d/10647.misc new file mode 100644 index 000000000000..4407a9030d55 --- /dev/null +++ b/changelog.d/10647.misc @@ -0,0 +1 @@ +Improve the performance of the `/hierarchy` API (from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)) by caching responses received over federation. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 44d9e8a5c734..1416abd0fba3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -111,6 +111,23 @@ def __init__(self, hs: "HomeServer"): reset_expiry_on_get=False, ) + # A cache for fetching the room hierarchy over federation. + # + # Some stale data over federation is OK, but must be refreshed + # periodically since the local server is in the room. + # + # It is a map of (room ID, suggested-only) -> the response of + # get_room_hierarchy. + self._get_room_hierarchy_cache: ExpiringCache[ + Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]] + ] = ExpiringCache( + cache_name="get_room_hierarchy_cache", + clock=self._clock, + max_len=1000, + expiry_ms=5 * 60 * 1000, + reset_expiry_on_get=False, + ) + def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -1324,6 +1341,10 @@ async def get_room_hierarchy( remote servers """ + cached_result = self._get_room_hierarchy_cache.get((room_id, suggested_only)) + if cached_result: + return cached_result + async def send_request( destination: str, ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: @@ -1370,58 +1391,63 @@ async def send_request( return room, children, inaccessible_children try: - return await self._try_destination_list( + result = await self._try_destination_list( "fetch room hierarchy", destinations, send_request, failover_on_unknown_endpoint=True, ) except SynapseError as e: + # If an unexpected error occurred, re-raise it. + if e.code != 502: + raise + # Fallback to the old federation API and translate the results if # no servers implement the new API. # # The algorithm below is a bit inefficient as it only attempts to - # get information for the requested room, but the legacy API may + # parse information for the requested room, but the legacy API may # return additional layers. - if e.code == 502: - legacy_result = await self.get_space_summary( - destinations, - room_id, - suggested_only, - max_rooms_per_space=None, - exclude_rooms=[], - ) + legacy_result = await self.get_space_summary( + destinations, + room_id, + suggested_only, + max_rooms_per_space=None, + exclude_rooms=[], + ) - # Find the requested room in the response (and remove it). - for _i, room in enumerate(legacy_result.rooms): - if room.get("room_id") == room_id: - break - else: - # The requested room was not returned, nothing we can do. - raise - requested_room = legacy_result.rooms.pop(_i) - - # Find any children events of the requested room. - children_events = [] - children_room_ids = set() - for event in legacy_result.events: - if event.room_id == room_id: - children_events.append(event.data) - children_room_ids.add(event.state_key) - # And add them under the requested room. - requested_room["children_state"] = children_events - - # Find the children rooms. - children = [] - for room in legacy_result.rooms: - if room.get("room_id") in children_room_ids: - children.append(room) - - # It isn't clear from the response whether some of the rooms are - # not accessible. - return requested_room, children, () - - raise + # Find the requested room in the response (and remove it). + for _i, room in enumerate(legacy_result.rooms): + if room.get("room_id") == room_id: + break + else: + # The requested room was not returned, nothing we can do. + raise + requested_room = legacy_result.rooms.pop(_i) + + # Find any children events of the requested room. + children_events = [] + children_room_ids = set() + for event in legacy_result.events: + if event.room_id == room_id: + children_events.append(event.data) + children_room_ids.add(event.state_key) + # And add them under the requested room. + requested_room["children_state"] = children_events + + # Find the children rooms. + children = [] + for room in legacy_result.rooms: + if room.get("room_id") in children_room_ids: + children.append(room) + + # It isn't clear from the response whether some of the rooms are + # not accessible. + result = (requested_room, children, ()) + + # Cache the result to avoid fetching data over federation every time. + self._get_room_hierarchy_cache[(room_id, suggested_only)] = result + return result @attr.s(frozen=True, slots=True, auto_attribs=True) From 1aa0dad02187c3b972187f5952cfbce336b0ca5c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Aug 2021 07:53:52 -0400 Subject: [PATCH 18/28] Additional type hints for REST servlets (part 2). (#10674) Applies the changes from #10665 to additional modules. --- changelog.d/10674.misc | 1 + synapse/handlers/presence.py | 5 +++ synapse/rest/client/auth.py | 11 +++-- synapse/rest/client/devices.py | 48 ++++++++++++---------- synapse/rest/client/events.py | 38 ++++++++++------- synapse/rest/client/filter.py | 26 ++++++++---- synapse/rest/client/groups.py | 3 +- synapse/rest/client/initial_sync.py | 16 ++++++-- synapse/rest/client/keys.py | 57 ++++++++++---------------- synapse/rest/client/knock.py | 3 +- synapse/rest/client/login.py | 21 ++++------ synapse/rest/client/logout.py | 17 +++++--- synapse/rest/client/notifications.py | 13 ++++-- synapse/rest/client/openid.py | 16 ++++++-- synapse/rest/client/password_policy.py | 18 ++++---- synapse/rest/client/presence.py | 24 +++++++---- synapse/rest/client/profile.py | 37 ++++++++++++----- 17 files changed, 216 insertions(+), 138 deletions(-) create mode 100644 changelog.d/10674.misc diff --git a/changelog.d/10674.misc b/changelog.d/10674.misc new file mode 100644 index 000000000000..39a37b90b1b3 --- /dev/null +++ b/changelog.d/10674.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 7ca14e1d8473..4418d63df7f9 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -353,6 +353,11 @@ async def send_full_presence_to_users(self, user_ids: Collection[str]): # otherwise would not do). await self.set_state(UserID.from_string(user_id), state, force_notify=True) + async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: + raise NotImplementedError( + "Attempting to check presence on a non-presence worker." + ) + class _NullContextManager(ContextManager[None]): """A context manager which does nothing.""" diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 91800c02784c..df8cc4ac7ab5 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -15,11 +15,14 @@ import logging from typing import TYPE_CHECKING +from twisted.web.server import Request + from synapse.api.constants import LoginType from synapse.api.errors import LoginError, SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.server import respond_with_html +from synapse.http.server import HttpServer, respond_with_html from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest from ._base import client_patterns @@ -49,7 +52,7 @@ def __init__(self, hs: "HomeServer"): self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template - async def on_GET(self, request, stagetype): + async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -88,7 +91,7 @@ async def on_GET(self, request, stagetype): respond_with_html(request, 200, html) return None - async def on_POST(self, request, stagetype): + async def on_POST(self, request: Request, stagetype: str) -> None: session = parse_string(request, "session") if not session: @@ -172,5 +175,5 @@ async def on_POST(self, request, stagetype): return None -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8b9674db064a..25bc3c8f477b 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -14,34 +14,36 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api import errors +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DevicesRestServlet(RestServlet): PATTERNS = client_patterns("/devices$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) devices = await self.device_handler.get_devices_by_user( requester.user.to_string() @@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = client_patterns("/delete_devices") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -65,7 +67,7 @@ def __init__(self, hs): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: @@ -100,18 +102,16 @@ async def on_POST(self, request): class DeviceRestServlet(RestServlet): PATTERNS = client_patterns("/devices/(?P[^/]*)$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - async def on_GET(self, request, device_id): + async def on_GET( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) device = await self.device_handler.get_device( requester.user.to_string(), device_id @@ -119,7 +119,9 @@ async def on_GET(self, request, device_id): return 200, device @interactive_auth_handler - async def on_DELETE(self, request, device_id): + async def on_DELETE( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: @@ -146,7 +148,9 @@ async def on_DELETE(self, request, device_id): await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - async def on_PUT(self, request, device_id): + async def on_PUT( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) @@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet): PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_GET(self, request: SynapseRequest): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) dehydrated_device = await self.device_handler.get_dehydrated_device( requester.user.to_string() @@ -211,7 +215,7 @@ async def on_GET(self, request: SynapseRequest): else: raise errors.NotFoundError("No dehydrated device available") - async def on_PUT(self, request: SynapseRequest): + async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: submission = parse_json_object_from_request(request) requester = await self.auth.get_user_by_req(request) @@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet): "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_POST(self, request: SynapseRequest): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) submission = parse_json_object_from_request(request) @@ -292,7 +296,7 @@ async def on_POST(self, request: SynapseRequest): return (200, result) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 52bb579cfd40..13b72a045a4a 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -14,11 +14,18 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging +from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest - room_id = None + args: Dict[bytes, List[bytes]] = request.args # type: ignore if is_guest: - if b"room_id" not in request.args: + if b"room_id" not in args: raise SynapseError(400, "Guest users must specify room_id param") - if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode("ascii") + room_id = parse_string(request, "room_id") pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if b"timeout" in request.args: + if b"timeout" in args: try: - timeout = int(request.args[b"timeout"][0]) + timeout = int(args[b"timeout"][0]) except ValueError: raise SynapseError(400, "timeout must be in milliseconds.") - as_client_event = b"raw" not in request.args + as_client_event = b"raw" not in args chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), @@ -70,25 +76,27 @@ async def on_GET(self, request): class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request, event_id): + async def on_GET( + self, request: SynapseRequest, event_id: str + ) -> Tuple[int, Union[str, JsonDict]]: requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event + result = await self._event_serializer.serialize_event(event, time_now) + return 200, result else: return 404, "Event not found." -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EventStreamRestServlet(hs).register(http_server) EventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 411667a9c8d9..6ed60c74181f 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -13,26 +13,34 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID from ._base import client_patterns, set_timeline_upper_limit +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() - async def on_GET(self, request, user_id, filter_id): + async def on_GET( + self, request: SynapseRequest, user_id: str, filter_id: str + ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -43,13 +51,13 @@ async def on_GET(self, request, user_id, filter_id): raise AuthError(403, "Can only get filters for local users") try: - filter_id = int(filter_id) + filter_id_int = int(filter_id) except Exception: raise SynapseError(400, "Invalid filter_id") try: filter_collection = await self.filtering.get_user_filter( - user_localpart=target_user.localpart, filter_id=filter_id + user_localpart=target_user.localpart, filter_id=filter_id_int ) except StoreError as e: if e.code != 404: @@ -62,13 +70,15 @@ async def on_GET(self, request, user_id, filter_id): class CreateFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -89,6 +99,6 @@ async def on_POST(self, request, user_id): return 200, {"filter_id": str(filter_id)} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: GetFilterRestServlet(hs).register(http_server) CreateFilterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index 6285680c00c8..c3667ff8aaad 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -26,6 +26,7 @@ ) from synapse.api.errors import Codes, SynapseError from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -930,7 +931,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return 200, result -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: GroupServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 12ba0e91dbd1..49b1037b28d4 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -12,25 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Dict, List, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer # TODO: Needs unit testing class InitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/initialSync$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - as_client_event = b"raw" not in request.args + args: Dict[bytes, List[bytes]] = request.args # type: ignore + as_client_event = b"raw" not in args pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( @@ -43,5 +51,5 @@ async def on_GET(self, request): return 200, content -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: InitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 012491f59736..7281b2ee2912 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,20 +15,25 @@ # limitations under the License. import logging -from typing import Any +from typing import TYPE_CHECKING, Any, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.types import StreamToken +from synapse.types import JsonDict, StreamToken from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -60,18 +65,16 @@ class KeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") - async def on_POST(self, request, device_id): + async def on_POST( + self, request: SynapseRequest, device_id: Optional[str] + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -149,16 +152,12 @@ class KeyQueryServlet(RestServlet): PATTERNS = client_patterns("/keys/query$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() device_id = requester.device_id @@ -195,17 +194,13 @@ class KeyChangesServlet(RestServlet): PATTERNS = client_patterns("/keys/changes$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from", required=True) @@ -245,12 +240,12 @@ class OneTimeKeyServlet(RestServlet): PATTERNS = client_patterns("/keys/claim$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) @@ -269,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -281,7 +272,7 @@ def __init__(self, hs): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -329,16 +320,12 @@ class SignaturesUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/signatures/upload$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -349,7 +336,7 @@ async def on_POST(self, request): return 200, result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 7d1bc40658a6..68fb08d0ba0f 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -19,6 +19,7 @@ from synapse.api.constants import Membership from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -103,5 +104,5 @@ def on_PUT(self, request: Request, room_identifier: str, txn_id: str): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 11d07776b2ff..4be502a77b04 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -1,4 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing_extensions import TypedDict @@ -110,7 +110,7 @@ def __init__(self, hs: "HomeServer"): # counters are initialised for the auth_provider_ids. _load_sso_handlers(hs) - def on_GET(self, request: SynapseRequest): + def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) @@ -157,7 +157,7 @@ def on_GET(self, request: SynapseRequest): return 200, {"flows": flows} - async def on_POST(self, request: SynapseRequest): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: login_submission = parse_json_object_from_request(request) if self._msc2918_enabled: @@ -217,7 +217,7 @@ async def _do_appservice_login( login_submission: JsonDict, appservice: ApplicationService, should_issue_refresh_token: bool = False, - ): + ) -> LoginResponse: identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -467,10 +467,7 @@ def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self.access_token_lifetime = hs.config.access_token_lifetime - async def on_POST( - self, - request: SynapseRequest, - ): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) assert_params_in_dict(refresh_submission, ["refresh_token"]) @@ -570,7 +567,7 @@ async def on_GET( class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._cas_handler = hs.get_cas_handler() @@ -592,7 +589,7 @@ async def on_GET(self, request: SynapseRequest) -> None: ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) if hs.config.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) @@ -601,7 +598,7 @@ def register_servlets(hs, http_server): CasTicketServlet(hs).register(http_server) -def _load_sso_handlers(hs: "HomeServer"): +def _load_sso_handlers(hs: "HomeServer") -> None: """Ensure that the SSO handlers are loaded, if they are enabled by configuration. This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 6055cac2bd0a..193a6951b91b 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -13,9 +13,16 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -23,13 +30,13 @@ class LogoutRestServlet(RestServlet): PATTERNS = client_patterns("/logout$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) if requester.device_id is None: @@ -48,13 +55,13 @@ async def on_POST(self, request): class LogoutAllRestServlet(RestServlet): PATTERNS = client_patterns("/logout/all$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() @@ -67,6 +74,6 @@ async def on_POST(self, request): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LogoutRestServlet(hs).register(http_server) LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 0ede643c2d91..d1d8a984c630 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -13,26 +13,33 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.events.utils import format_event_for_client_v2_without_room_id +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -87,5 +94,5 @@ async def on_GET(self, request): return 200, {"notifications": returned_push_actions, "next_token": next_token} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index e8d2673819cb..4dda6dce4ba1 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from synapse.util.stringutils import random_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet): EXPIRES_MS = 3600 * 1000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() self.server_name = hs.config.server_name - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -90,5 +98,5 @@ async def on_POST(self, request, user_id): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: IdTokenServlet(hs).register(http_server) diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py index a83927aee641..6d64efb1651d 100644 --- a/synapse/rest/client/password_policy.py +++ b/synapse/rest/client/password_policy.py @@ -13,28 +13,32 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class PasswordPolicyServlet(RestServlet): PATTERNS = client_patterns("/password_policy$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled - def on_GET(self, request): + def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.enabled or not self.policy: return (200, {}) @@ -53,5 +57,5 @@ def on_GET(self, request): return (200, policy) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index 6c27e5faf986..94dd4fe2f45c 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -15,12 +15,18 @@ """ This module contains REST servlets to do with presence: /presence/ """ import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,7 +34,7 @@ class PresenceStatusRestServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.presence_handler = hs.get_presence_handler() @@ -37,7 +43,9 @@ def __init__(self, hs): self._use_presence = hs.config.server.use_presence - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) @@ -53,13 +61,15 @@ async def on_GET(self, request, user_id): raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state( + result = format_user_presence_state( state, self.clock.time_msec(), include_user_id=False ) - return 200, state + return 200, result - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) @@ -91,5 +101,5 @@ async def on_PUT(self, request, user_id): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 5463ed2c4f85..d0f20de569c0 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -14,22 +14,31 @@ """ This module contains REST servlets to do with profile: /profile/ """ +from typing import TYPE_CHECKING, Tuple + from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -48,7 +57,9 @@ async def on_GET(self, request, user_id): return 200, ret - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester.user) @@ -72,13 +83,15 @@ async def on_PUT(self, request, user_id): class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -97,7 +110,9 @@ async def on_GET(self, request, user_id): return 200, ret - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester.user) @@ -120,13 +135,15 @@ async def on_PUT(self, request, user_id): class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -149,7 +166,7 @@ async def on_GET(self, request, user_id): return 200, ret -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ProfileDisplaynameRestServlet(hs).register(http_server) ProfileAvatarURLRestServlet(hs).register(http_server) ProfileRestServlet(hs).register(http_server) From ad17fbd20eb2dd9fb10a3d02ab1b69e9a0d5b50c Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Thu, 26 Aug 2021 13:53:57 +0100 Subject: [PATCH 19/28] Remove pushers when deleting 3pid from account (#10581) When a user deletes an email from their account it will now also remove all pushers for that email and that user (even if these pushers were created by a different client) --- CHANGES.md | 2 + changelog.d/10581.bugfix | 1 + docs/upgrade.md | 5 ++ synapse/handlers/auth.py | 5 +- synapse/storage/databases/main/pusher.py | 72 +++++++++++++++++++ .../63/02delete_unlinked_email_pushers.sql | 20 ++++++ tests/push/test_email.py | 39 ++++++++++ 7 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10581.bugfix create mode 100644 synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql diff --git a/CHANGES.md b/CHANGES.md index f8da8771aa6e..24f3d53a6d31 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,5 @@ +Users will stop receiving message updates via email for addresses that were previously linked to their account + Synapse 1.41.0 (2021-08-24) =========================== diff --git a/changelog.d/10581.bugfix b/changelog.d/10581.bugfix new file mode 100644 index 000000000000..15c7da449734 --- /dev/null +++ b/changelog.d/10581.bugfix @@ -0,0 +1 @@ +Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 6d4b8cb48edb..dcf0a7db5bf6 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -107,6 +107,11 @@ This may affect you if you make use of custom HTML templates for the The template is now provided an `error` variable if the authentication process failed. See the default templates linked above for an example. +# Upgrading to v1.42.0 + +## Removal of out-of-date email pushers +Users will stop receiving message updates via email for addresses that were +once, but not still, linked to their account. # Upgrading to v1.41.0 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 98d3d2d97faf..34725324a652 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1464,6 +1464,10 @@ async def delete_threepid( ) await self.store.user_delete_threepid(user_id, medium, address) + if medium == "email": + await self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id="m.email", pushkey=address, user_id=user_id + ) return result async def hash(self, password: str) -> str: @@ -1732,7 +1736,6 @@ def add_query_param_to_url(url: str, param_name: str, param: Any): @attr.s(slots=True) class MacaroonGenerator: - hs = attr.ib() def generate_guest_access_token(self, user_id: str) -> str: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index b48fe086d4cc..e47caa212549 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -48,6 +48,11 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer" self._remove_stale_pushers, ) + self.db_pool.updates.register_background_update_handler( + "remove_deleted_email_pushers", + self._remove_deleted_email_pushers, + ) + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table @@ -388,6 +393,73 @@ def _delete_pushers(txn) -> int: return number_deleted + async def _remove_deleted_email_pushers( + self, progress: dict, batch_size: int + ) -> int: + """A background update that deletes all pushers for deleted email addresses. + + In previous versions of synapse, when users deleted their email address, it didn't + also delete all the pushers for that email address. This background update removes + those to prevent unwanted emails. This should only need to be run once (when users + upgrade to v1.42.0 + + Args: + progress: dict used to store progress of this background update + batch_size: the maximum number of rows to retrieve in a single select query + + Returns: + The number of deleted rows + """ + + last_pusher = progress.get("last_pusher", 0) + + def _delete_pushers(txn) -> int: + + sql = """ + SELECT p.id, p.user_name, p.app_id, p.pushkey + FROM pushers AS p + LEFT JOIN user_threepids AS t + ON t.user_id = p.user_name + AND t.medium = 'email' + AND t.address = p.pushkey + WHERE t.user_id is NULL + AND p.app_id = 'm.email' + AND p.id > ? + ORDER BY p.id ASC + LIMIT ? + """ + + txn.execute(sql, (last_pusher, batch_size)) + + last = None + num_deleted = 0 + for row in txn: + last = row[0] + num_deleted += 1 + self.db_pool.simple_delete_txn( + txn, + "pushers", + {"user_name": row[1], "app_id": row[2], "pushkey": row[3]}, + ) + + if last is not None: + self.db_pool.updates._background_update_progress_txn( + txn, "remove_deleted_email_pushers", {"last_pusher": last} + ) + + return num_deleted + + number_deleted = await self.db_pool.runInteraction( + "_remove_deleted_email_pushers", _delete_pushers + ) + + if number_deleted < batch_size: + await self.db_pool.updates._end_background_update( + "remove_deleted_email_pushers" + ) + + return number_deleted + class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self) -> int: diff --git a/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql new file mode 100644 index 000000000000..611c4b95cf15 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We may not have deleted all pushers for emails that are no longer linked +-- to an account, so we set up a background job to delete them. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6302, 'remove_deleted_email_pushers', '{}'); diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e0a3342088d4..eea07485a017 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -125,6 +125,8 @@ def prepare(self, reactor, clock, hs): ) ) + self.auth_handler = hs.get_auth_handler() + def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated their email. @@ -305,6 +307,43 @@ def test_encrypted_message(self): # We should get emailed about that message self._check_for_mail() + def test_no_email_sent_after_removed(self): + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, + src=self.user_id, + tok=self.access_token, + targ=self.others[0].id, + ) + self.helper.join( + room=room, + user=self.others[0].id, + tok=self.others[0].token, + ) + + # The other user sends a single message. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + + # We should get emailed about that message + self._check_for_mail() + + # disassociate the user's email address + self.get_success( + self.auth_handler.delete_threepid( + user_id=self.user_id, + medium="email", + address="a@example.com", + ) + ) + + # check that the pusher for that email address has been deleted + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 0) + def _check_for_mail(self): """Check that the user receives an email notification""" From 40f619eaa54d2391deccec473fc0f655c379e766 Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Thu, 26 Aug 2021 11:07:58 -0500 Subject: [PATCH 20/28] Validate new m.room.power_levels events (#10232) Signed-off-by: Aaron Raimist --- changelog.d/10232.bugfix | 1 + synapse/events/utils.py | 5 +- synapse/events/validator.py | 77 ++++++++++++++++++++++++- synapse/python_dependencies.py | 3 +- tests/rest/client/test_power_levels.py | 78 ++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10232.bugfix diff --git a/changelog.d/10232.bugfix b/changelog.d/10232.bugfix new file mode 100644 index 000000000000..7be72271e018 --- /dev/null +++ b/changelog.d/10232.bugfix @@ -0,0 +1 @@ +Validate new `m.room.power_levels` events. Contributed by @aaronraimist. \ No newline at end of file diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b6da2f60af99..738a151cefd1 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -32,6 +32,9 @@ # the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" SPLIT_FIELD_REGEX = re.compile(r"(? EventBase: """Returns a pruned version of the given event, which removes all keys we @@ -505,7 +508,7 @@ def validate_canonicaljson(value: Any): * NaN, Infinity, -Infinity """ if isinstance(value, int): - if value <= -(2 ** 53) or 2 ** 53 <= value: + if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) elif isinstance(value, float): diff --git a/synapse/events/validator.py b/synapse/events/validator.py index fa6987d7cbac..33954b4f6217 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -11,16 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import collections.abc from typing import Union +import jsonschema + from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.events.utils import validate_canonicaljson +from synapse.events.utils import ( + CANONICALJSON_MAX_INT, + CANONICALJSON_MIN_INT, + validate_canonicaljson, +) from synapse.federation.federation_server import server_matches_acl_event from synapse.types import EventID, RoomID, UserID @@ -87,6 +93,29 @@ def validate_new(self, event: EventBase, config: HomeServerConfig): 400, "Can't create an ACL event that denies the local server" ) + if event.type == EventTypes.PowerLevels: + try: + jsonschema.validate( + instance=event.content, + schema=POWER_LEVELS_SCHEMA, + cls=plValidator, + ) + except jsonschema.ValidationError as e: + if e.path: + # example: "users_default": '0' is not of type 'integer' + message = '"' + e.path[-1] + '": ' + e.message # noqa: B306 + # jsonschema.ValidationError.message is a valid attribute + else: + # example: '0' is not of type 'integer' + message = e.message # noqa: B306 + # jsonschema.ValidationError.message is a valid attribute + + raise SynapseError( + code=400, + msg=message, + errcode=Codes.BAD_JSON, + ) + def _validate_retention(self, event: EventBase): """Checks that an event that defines the retention policy for a room respects the format enforced by the spec. @@ -185,3 +214,47 @@ def _ensure_strings(self, d, keys): def _ensure_state_event(self, event): if not event.is_state(): raise SynapseError(400, "'%s' must be state events" % (event.type,)) + + +POWER_LEVELS_SCHEMA = { + "type": "object", + "properties": { + "ban": {"$ref": "#/definitions/int"}, + "events": {"$ref": "#/definitions/objectOfInts"}, + "events_default": {"$ref": "#/definitions/int"}, + "invite": {"$ref": "#/definitions/int"}, + "kick": {"$ref": "#/definitions/int"}, + "notifications": {"$ref": "#/definitions/objectOfInts"}, + "redact": {"$ref": "#/definitions/int"}, + "state_default": {"$ref": "#/definitions/int"}, + "users": {"$ref": "#/definitions/objectOfInts"}, + "users_default": {"$ref": "#/definitions/int"}, + }, + "definitions": { + "int": { + "type": "integer", + "minimum": CANONICALJSON_MIN_INT, + "maximum": CANONICALJSON_MAX_INT, + }, + "objectOfInts": { + "type": "object", + "additionalProperties": {"$ref": "#/definitions/int"}, + }, + }, +} + + +def _create_power_level_validator(): + validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) + + # by default jsonschema does not consider a frozendict to be an object so + # we need to use a custom type checker + # https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types + type_checker = validator.TYPE_CHECKER.redefine( + "object", lambda checker, thing: isinstance(thing, collections.abc.Mapping) + ) + + return jsonschema.validators.extend(validator, type_checker=type_checker) + + +plValidator = _create_power_level_validator() diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index cdcbdd772b14..154e5b7028e9 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -48,7 +48,8 @@ # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. REQUIREMENTS = [ - "jsonschema>=2.5.1", + # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0 + "jsonschema>=3.0.0", "frozendict>=1", "unpaddedbase64>=1.1.0", "canonicaljson>=1.4.0", diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 91d0762cb0ab..c0de4c93a806 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.errors import Codes +from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync @@ -203,3 +205,79 @@ def test_admins_can_tombstone_room(self): tok=self.admin_access_token, expect_code=200, # expect success ) + + def test_cannot_set_string_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL "0" + room_power_levels["users"].update({self.user_user_id: "0"}) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_large_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL above the max safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MAX_INT + 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_small_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL below the minimum safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MIN_INT - 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) From 96715d763362a7027c39d571cfde3aa8b7b82fcf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 26 Aug 2021 18:34:57 +0100 Subject: [PATCH 21/28] Make `backfill` and `get_missing_events` use the same codepath (#10645) Given that backfill and get_missing_events are basically the same thing, it's somewhat crazy that we have entirely separate code paths for them. This makes backfill use the existing get_missing_events code, and then clears up all the unused code. --- changelog.d/10645.misc | 1 + synapse/handlers/federation.py | 273 +++--------------- .../storage/databases/main/purge_events.py | 1 + 3 files changed, 42 insertions(+), 233 deletions(-) create mode 100644 changelog.d/10645.misc diff --git a/changelog.d/10645.misc b/changelog.d/10645.misc new file mode 100644 index 000000000000..ac19263cd861 --- /dev/null +++ b/changelog.d/10645.misc @@ -0,0 +1 @@ +Make `backfill` and `get_missing_events` use the same codepath. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 246df43501bc..6fa2fc8f5284 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -65,6 +65,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.federation.federation_client import InvalidResponseError from synapse.handlers._base import BaseHandler from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -116,10 +117,6 @@ class _NewEventInfo: Attributes: event: the received event - state: the state at that event, according to /state_ids from a remote - homeserver. Only populated for backfilled events which are going to be a - new backwards extremity. - claimed_auth_event_map: a map of (type, state_key) => event for the event's claimed auth_events. @@ -134,7 +131,6 @@ class _NewEventInfo: """ event: EventBase - state: Optional[Sequence[EventBase]] claimed_auth_event_map: StateMap[EventBase] @@ -443,113 +439,7 @@ async def _get_missing_events_for_pdu( return logger.info("Got %d prev_events", len(missing_events)) - await self._process_pulled_events(origin, missing_events) - - async def _get_state_for_room( - self, - destination: str, - room_id: str, - event_id: str, - ) -> List[EventBase]: - """Requests all of the room state at a given event from a remote - homeserver. - - Will also fetch any missing events reported in the `auth_chain_ids` - section of `/state_ids`. - - Args: - destination: The remote homeserver to query for the state. - room_id: The id of the room we're interested in. - event_id: The id of the event we want the state at. - - Returns: - A list of events in the state, not including the event itself. - """ - ( - state_event_ids, - auth_event_ids, - ) = await self.federation_client.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - # Fetch the state events from the DB, and check we have the auth events. - event_map = await self.store.get_events(state_event_ids, allow_rejected=True) - auth_events_in_store = await self.store.have_seen_events( - room_id, auth_event_ids - ) - - # Check for missing events. We handle state and auth event seperately, - # as we want to pull the state from the DB, but we don't for the auth - # events. (Note: we likely won't use the majority of the auth chain, and - # it can be *huge* for large rooms, so it's worth ensuring that we don't - # unnecessarily pull it from the DB). - missing_state_events = set(state_event_ids) - set(event_map) - missing_auth_events = set(auth_event_ids) - set(auth_events_in_store) - if missing_state_events or missing_auth_events: - await self._get_events_and_persist( - destination=destination, - room_id=room_id, - events=missing_state_events | missing_auth_events, - ) - - if missing_state_events: - new_events = await self.store.get_events( - missing_state_events, allow_rejected=True - ) - event_map.update(new_events) - - missing_state_events.difference_update(new_events) - - if missing_state_events: - logger.warning( - "Failed to fetch missing state events for %s %s", - event_id, - missing_state_events, - ) - - if missing_auth_events: - auth_events_in_store = await self.store.have_seen_events( - room_id, missing_auth_events - ) - missing_auth_events.difference_update(auth_events_in_store) - - if missing_auth_events: - logger.warning( - "Failed to fetch missing auth events for %s %s", - event_id, - missing_auth_events, - ) - - remote_state = list(event_map.values()) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - - bad_events = [ - (event.event_id, event.room_id) - for event in remote_state - if event.room_id != room_id - ] - - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned auth/state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) - - if bad_events: - remote_state = [e for e in remote_state if e.room_id == room_id] - - return remote_state + await self._process_pulled_events(origin, missing_events, backfilled=False) async def _get_state_after_missing_prev_event( self, @@ -567,10 +457,6 @@ async def _get_state_after_missing_prev_event( Returns: A list of events in the state, including the event itself """ - # TODO: This function is basically the same as _get_state_for_room. Can - # we make backfill() use it, rather than having two code paths? I think the - # only difference is that backfill() persists the prev events separately. - ( state_event_ids, auth_event_ids, @@ -681,6 +567,7 @@ async def _process_received_pdu( origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + backfilled: bool = False, ) -> None: """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -693,6 +580,9 @@ async def _process_received_pdu( state: Normally None, but if we are handling a gap in the graph (ie, we are missing one or more prev_events), the resolved state at the event + + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) """ logger.debug("Processing event: %s", event) @@ -700,10 +590,15 @@ async def _process_received_pdu( context = await self.state_handler.compute_event_context( event, old_state=state ) - await self._auth_and_persist_event(origin, event, context, state=state) + await self._auth_and_persist_event( + origin, event, context, state=state, backfilled=backfilled + ) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + if backfilled: + return + # For encrypted messages we check that we know about the sending device, # if we don't then we mark the device cache for that user as stale. if event.type == EventTypes.Encrypted: @@ -868,7 +763,7 @@ async def _resync_device(self, sender: str) -> None: @log_function async def backfill( self, dest: str, room_id: str, limit: int, extremities: List[str] - ) -> List[EventBase]: + ) -> None: """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -878,6 +773,9 @@ async def backfill( sanity-checking on them. If any of the backfilled events are invalid, this method throws a SynapseError. + We might also raise an InvalidResponseError if the response from the remote + server is just bogus. + TODO: make this more useful to distinguish failures of the remote server from invalid events (there is probably no point in trying to re-fetch invalid events from every other HS in the room.) @@ -890,111 +788,18 @@ async def backfill( ) if not events: - return [] - - # ideally we'd sanity check the events here for excess prev_events etc, - # but it's hard to reject events at this point without completely - # breaking backfill in the same way that it is currently broken by - # events whose signature we cannot verify (#3121). - # - # So for now we accept the events anyway. #3124 tracks this. - # - # for ev in events: - # self._sanity_check_event(ev) - - # Don't bother processing events we already have. - seen_events = await self.store.have_events_in_timeline( - {e.event_id for e in events} - ) - - events = [e for e in events if e.event_id not in seen_events] - - if not events: - return [] - - event_map = {e.event_id: e for e in events} - - event_ids = {e.event_id for e in events} - - # build a list of events whose prev_events weren't in the batch. - # (XXX: this will include events whose prev_events we already have; that doesn't - # sound right?) - edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] - - logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) - - # For each edge get the current state. - - state_events = {} - events_to_state = {} - for e_id in edges: - state = await self._get_state_for_room( - destination=dest, - room_id=room_id, - event_id=e_id, - ) - state_events.update({s.event_id: s for s in state}) - events_to_state[e_id] = state + return - required_auth = { - a_id - for event in events + list(state_events.values()) - for a_id in event.auth_event_ids() - } - auth_events = await self.store.get_events(required_auth, allow_rejected=True) - auth_events.update( - {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} - ) - - ev_infos = [] - - # Step 1: persist the events in the chunk we fetched state for (i.e. - # the backwards extremities), with custom auth events and state - for e_id in events_to_state: - # For paranoia we ensure that these events are marked as - # non-outliers - ev = event_map[e_id] - assert not ev.internal_metadata.is_outlier() - - ev_infos.append( - _NewEventInfo( - event=ev, - state=events_to_state[e_id], - claimed_auth_event_map={ - ( - auth_events[a_id].type, - auth_events[a_id].state_key, - ): auth_events[a_id] - for a_id in ev.auth_event_ids() - if a_id in auth_events - }, + # if there are any events in the wrong room, the remote server is buggy and + # should not be trusted. + for ev in events: + if ev.room_id != room_id: + raise InvalidResponseError( + f"Remote server {dest} returned event {ev.event_id} which is in " + f"room {ev.room_id}, when we were backfilling in {room_id}" ) - ) - - if ev_infos: - await self._auth_and_persist_events( - dest, room_id, ev_infos, backfilled=True - ) - - # Step 2: Persist the rest of the events in the chunk one by one - events.sort(key=lambda e: e.depth) - - for event in events: - if event in events_to_state: - continue - - # For paranoia we ensure that these events are marked as - # non-outliers - assert not event.internal_metadata.is_outlier() - - context = await self.state_handler.compute_event_context(event) - - # We store these one at a time since each event depends on the - # previous to work out the state. - # TODO: We can probably do something more clever here. - await self._auth_and_persist_event(dest, event, context, backfilled=True) - return events + await self._process_pulled_events(dest, events, backfilled=True) async def maybe_backfill( self, room_id: str, current_depth: int, limit: int @@ -1197,7 +1002,7 @@ async def try_backfill(domains: List[str]) -> bool: # appropriate stuff. # TODO: We can probably do something more intelligent here. return True - except SynapseError as e: + except (SynapseError, InvalidResponseError) as e: logger.info("Failed to backfill from %s because %s", dom, e) continue except HttpResponseException as e: @@ -1351,7 +1156,7 @@ async def get_event(event_id: str): else: logger.info("Missing auth event %s", auth_event_id) - event_infos.append(_NewEventInfo(event, None, auth)) + event_infos.append(_NewEventInfo(event, auth)) if event_infos: await self._auth_and_persist_events( @@ -1361,7 +1166,7 @@ async def get_event(event_id: str): ) async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase] + self, origin: str, events: Iterable[EventBase], backfilled: bool ) -> None: """Process a batch of events we have pulled from a remote server @@ -1373,6 +1178,8 @@ async def _process_pulled_events( Params: origin: The server we received these events from events: The received events. + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) """ # We want to sort these by depth so we process them and @@ -1381,9 +1188,11 @@ async def _process_pulled_events( for ev in sorted_events: with nested_logging_context(ev.event_id): - await self._process_pulled_event(origin, ev) + await self._process_pulled_event(origin, ev, backfilled=backfilled) - async def _process_pulled_event(self, origin: str, event: EventBase) -> None: + async def _process_pulled_event( + self, origin: str, event: EventBase, backfilled: bool + ) -> None: """Process a single event that we have pulled from a remote server Pulls in any events required to auth the event, persists the received event, @@ -1400,6 +1209,8 @@ async def _process_pulled_event(self, origin: str, event: EventBase) -> None: Params: origin: The server we received this event from events: The received event + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) """ logger.info("Processing pulled event %s", event) @@ -1428,7 +1239,9 @@ async def _process_pulled_event(self, origin: str, event: EventBase) -> None: try: state = await self._resolve_state_at_missing_prevs(origin, event) - await self._process_received_pdu(origin, event, state=state) + await self._process_received_pdu( + origin, event, state=state, backfilled=backfilled + ) except FederationError as e: if e.code == 403: logger.warning("Pulled event %s failed history check.", event_id) @@ -2451,7 +2264,6 @@ async def _auth_and_persist_events( origin: str, room_id: str, event_infos: Collection[_NewEventInfo], - backfilled: bool = False, ) -> None: """Creates the appropriate contexts and persists events. The events should not depend on one another, e.g. this should be used to persist @@ -2467,16 +2279,12 @@ async def _auth_and_persist_events( async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context( - event, old_state=ev_info.state - ) + res = await self.state_handler.compute_event_context(event) res = await self._check_event_auth( origin, event, res, - state=ev_info.state, claimed_auth_event_map=ev_info.claimed_auth_event_map, - backfilled=backfilled, ) return res @@ -2493,7 +2301,6 @@ async def prep(ev_info: _NewEventInfo): (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) ], - backfilled=backfilled, ) async def _persist_auth_tree( diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 664c65dac5a6..bccff5e5b95c 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -295,6 +295,7 @@ def _purge_history_txn( self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) + self._invalidate_get_event_cache(event_id) logger.info("[purge] done") From 1800aabfc226938036479d2ab1a750aa34cf3974 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 26 Aug 2021 21:41:44 +0100 Subject: [PATCH 22/28] Split `FederationHandler` in half (#10692) The idea here is to take anything to do with incoming events and move it out to a separate handler, as a way of making FederationHandler smaller. --- changelog.d/10692.misc | 1 + synapse/federation/federation_server.py | 7 +- synapse/handlers/federation.py | 1789 +--------------- synapse/handlers/federation_event.py | 1825 +++++++++++++++++ synapse/replication/http/federation.py | 4 +- synapse/server.py | 5 + tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_federation.py | 16 +- tests/handlers/test_presence.py | 4 +- .../test_federation_sender_shard.py | 2 +- tests/test_federation.py | 10 +- 11 files changed, 1884 insertions(+), 1781 deletions(-) create mode 100644 changelog.d/10692.misc create mode 100644 synapse/handlers/federation_event.py diff --git a/changelog.d/10692.misc b/changelog.d/10692.misc new file mode 100644 index 000000000000..a1b0def76b20 --- /dev/null +++ b/changelog.d/10692.misc @@ -0,0 +1 @@ +Split the event-processing methods in `FederationHandler` into a separate `FederationEventHandler`. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e1b58d40c533..214ee948fa3d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -110,6 +110,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_federation_handler() + self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -787,7 +788,9 @@ async def _on_send_membership_event( event = await self._check_sigs_and_hash(room_version, event) - return await self.handler.on_send_membership_event(origin, event) + return await self._federation_event_handler.on_send_membership_event( + origin, event + ) async def on_event_auth( self, origin: str, room_id: str, event_id: str @@ -1005,7 +1008,7 @@ async def _process_incoming_pdus_in_room_inner( async with lock: logger.info("handling received PDU: %s", event) try: - await self.handler.on_receive_pdu(origin, event) + await self._federation_event_handler.on_receive_pdu(origin, event) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6fa2fc8f5284..daf1d3bfb35a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,23 +17,9 @@ import itertools import logging -from collections.abc import Container from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union -import attr -from prometheus_client import Counter from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -41,19 +27,12 @@ from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import ( - EventContentFields, - EventTypes, - Membership, - RejectedReason, - RoomEncryptionAlgorithms, -) +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.errors import ( AuthError, CodeMessageException, Codes, FederationDeniedError, - FederationError, HttpResponseException, NotFoundError, RequestSendFailed, @@ -61,7 +40,6 @@ ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature -from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -75,28 +53,14 @@ run_in_background, ) from synapse.logging.utils import log_function -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, - ReplicationFederationSendEventsRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, ) -from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import ( - JsonDict, - MutableStateMap, - PersistedEventPosition, - RoomStreamToken, - StateMap, - UserID, - get_domain_from_id, -) -from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.iterutils import batch_iter +from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination -from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server if TYPE_CHECKING: @@ -104,45 +68,11 @@ logger = logging.getLogger(__name__) -soft_failed_event_counter = Counter( - "synapse_federation_soft_failed_events_total", - "Events received over federation that we marked as soft_failed", -) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _NewEventInfo: - """Holds information about a received event, ready for passing to _auth_and_persist_events - - Attributes: - event: the received event - - claimed_auth_event_map: a map of (type, state_key) => event for the event's - claimed auth_events. - - This can include events which have not yet been persisted, in the case that - we are backfilling a batch of events. - - Note: May be incomplete: if we were unable to find all of the claimed auth - events. Also, treat the contents with caution: the events might also have - been rejected, might not yet have been authorized themselves, or they might - be in the wrong room. - - """ - - event: EventBase - claimed_auth_event_map: StateMap[EventBase] - class FederationHandler(BaseHandler): - """Handles events that originated from federation. - Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resolutions) - b) converting events that were produced by local clients that may need - to be sent to remote homeservers. - c) doing the necessary dances to invite remote users and join remote - rooms. + """Handles general incoming federation requests + + Incoming events are *not* handled here, for which see FederationEventHandler. """ def __init__(self, hs: "HomeServer"): @@ -155,652 +85,35 @@ def __init__(self, hs: "HomeServer"): self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() - self._state_resolution_handler = hs.get_state_resolution_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() - self.action_generator = hs.get_action_generator() self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() - self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_proxied_blacklisted_http_client() - self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() + self._federation_event_handler = hs.get_federation_event_handler() - self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs ) if hs.config.worker_app: - self._user_device_resync = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) self._maybe_store_room_on_outlier_membership = ( ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) else: - self._device_list_updater = hs.get_device_handler().device_list_updater self._maybe_store_room_on_outlier_membership = ( self.store.maybe_store_room_on_outlier_membership ) - # When joining a room we need to queue any events for that room up. - # For each room, a list of (pdu, origin) tuples. - self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} - self._room_pdu_linearizer = Linearizer("fed_room_pdu") - self._room_backfill = Linearizer("room_backfill") self.third_party_event_rules = hs.get_third_party_event_rules() - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - - async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: - """Process a PDU received via a federation /send/ transaction - - Args: - origin: server which initiated the /send/ transaction. Will - be used to fetch missing events or state. - pdu: received PDU - """ - - room_id = pdu.room_id - event_id = pdu.event_id - - # We reprocess pdus when we have seen them only as outliers - existing = await self.store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - - # FIXME: Currently we fetch an event again when we already have it - # if it has been marked as an outlier. - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "Ignoring received event %s which we have already seen", event_id - ) - return - if pdu.internal_metadata.is_outlier(): - logger.info( - "Ignoring received outlier %s which we already have as an outlier", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - - # do some initial sanity-checking of the event. In particular, make - # sure it doesn't have hundreds of prev_events or auth_events, which - # could cause a huge state resolution or cascade of event fetches. - try: - self._sanity_check_event(pdu) - except SynapseError as err: - logger.warning("Received event failed sanity checks") - raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) - - # If we are currently in the process of joining this room, then we - # queue up events for later processing. - if room_id in self.room_queues: - logger.info( - "Queuing PDU from %s for now: join in progress", - origin, - ) - self.room_queues[room_id].append((pdu, origin)) - return - - # If we're not in the room just ditch the event entirely. This is - # probably an old server that has come back and thinks we're still in - # the room (or we've been rejoined to the room by a state reset). - # - # Note that if we were never in the room then we would have already - # dropped the event, since we wouldn't know the room version. - is_in_room = await self._event_auth_handler.check_host_in_room( - room_id, self.server_name - ) - if not is_in_room: - logger.info( - "Ignoring PDU from %s as we're not in the room", - origin, - ) - return None - - # Check that the event passes auth based on the state at the event. This is - # done for events that are to be added to the timeline (non-outliers). - # - # Get missing pdus if necessary: - # - Fetching any missing prev events to fill in gaps in the graph - # - Fetching state if we have a hole in the graph - if not pdu.internal_metadata.is_outlier(): - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if missing_prevs: - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("min_depth: %d", min_depth) - - if min_depth is not None and pdu.depth > min_depth: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. - logger.info( - "Acquiring room lock to fetch %d missing prev_events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): - logger.info( - "Acquired room lock to fetch %d missing prev_events", - len(missing_prevs), - ) - - try: - await self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) - except Exception as e: - raise Exception( - "Error fetching missing prev_events for %s: %s" - % (event_id, e) - ) from e - - # Update the set of things we've seen after trying to - # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - logger.info("Found all missing prev_events") - - if missing_prevs: - # since this event was pushed to us, it is possible for it to - # become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very - # bad. Further, if the event was pushed to us, there is no excuse - # for us not to have all the prev_events. (XXX: apart from - # min_depth?) - # - # We therefore reject any such events. - logger.warning( - "Rejecting: failed to fetch %d prev events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - - await self._process_received_pdu(origin, pdu, state=None) - - async def _get_missing_events_for_pdu( - self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int - ) -> None: - """ - Args: - origin: Origin of the pdu. Will be called to get the missing events - pdu: received pdu - prevs: List of event ids which we are missing - min_depth: Minimum depth of events to return. - """ - - room_id = pdu.room_id - event_id = pdu.event_id - - seen = await self.store.have_events_in_timeline(prevs) - - if not prevs - seen: - return - - latest_list = await self.store.get_latest_event_ids_in_room(room_id) - - # We add the prev events that we have seen to the latest - # list to ensure the remote server doesn't give them to us - latest = set(latest_list) - latest |= seen - - logger.info( - "Requesting missing events between %s and %s", - shortstr(latest), - event_id, - ) - - # XXX: we set timeout to 10s to help workaround - # https://github.com/matrix-org/synapse/issues/1733. - # The reason is to avoid holding the linearizer lock - # whilst processing inbound /send transactions, causing - # FDs to stack up and block other inbound transactions - # which empirically can currently take up to 30 minutes. - # - # N.B. this explicitly disables retry attempts. - # - # N.B. this also increases our chances of falling back to - # fetching fresh state for the room if the missing event - # can't be found, which slightly reduces our security. - # it may also increase our DAG extremity count for the room, - # causing additional state resolution? See #1760. - # However, fetching state doesn't hold the linearizer lock - # apparently. - # - # see https://github.com/matrix-org/synapse/pull/1744 - # - # ---- - # - # Update richvdh 2018/09/18: There are a number of problems with timing this - # request out aggressively on the client side: - # - # - it plays badly with the server-side rate-limiter, which starts tarpitting you - # if you send too many requests at once, so you end up with the server carefully - # working through the backlog of your requests, which you have already timed - # out. - # - # - for this request in particular, we now (as of - # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the - # server can't produce a plausible-looking set of prev_events - so we becone - # much more likely to reject the event. - # - # - contrary to what it says above, we do *not* fall back to fetching fresh state - # for the room if get_missing_events times out. Rather, we give up processing - # the PDU whose prevs we are missing, which then makes it much more likely that - # we'll end up back here for the *next* PDU in the list, which exacerbates the - # problem. - # - # - the aggressive 10s timeout was introduced to deal with incoming federation - # requests taking 8 hours to process. It's not entirely clear why that was going - # on; certainly there were other issues causing traffic storms which are now - # resolved, and I think in any case we may be more sensible about our locking - # now. We're *certainly* more sensible about our logging. - # - # All that said: Let's try increasing the timeout to 60s and see what happens. - - try: - missing_events = await self.federation_client.get_missing_events( - origin, - room_id, - earliest_events_ids=list(latest), - latest_events=[pdu], - limit=10, - min_depth=min_depth, - timeout=60000, - ) - except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: - # We failed to get the missing events, but since we need to handle - # the case of `get_missing_events` not returning the necessary - # events anyway, it is safe to simply log the error and continue. - logger.warning("Failed to get prev_events: %s", e) - return - - logger.info("Got %d prev_events", len(missing_events)) - await self._process_pulled_events(origin, missing_events, backfilled=False) - - async def _get_state_after_missing_prev_event( - self, - destination: str, - room_id: str, - event_id: str, - ) -> List[EventBase]: - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination: The remote homeserver to query for the state. - room_id: The id of the room we're interested in. - event_id: The id of the event we want the state at. - - Returns: - A list of events in the state, including the event itself - """ - ( - state_event_ids, - auth_event_ids, - ) = await self.federation_client.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - logger.debug( - "state_ids returned %i state events, %i auth events", - len(state_event_ids), - len(auth_event_ids), - ) - - # start by just trying to fetch the events from the store - desired_events = set(state_event_ids) - desired_events.add(event_id) - logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self.store.get_events( - desired_events, allow_rejected=True - ) - - missing_desired_events = desired_events - fetched_events.keys() - logger.debug( - "We are missing %i events (got %i)", - len(missing_desired_events), - len(fetched_events), - ) - - # We probably won't need most of the auth events, so let's just check which - # we have for now, rather than thrashing the event cache with them all - # unnecessarily. - - # TODO: we probably won't actually need all of the auth events, since we - # already have a bunch of the state events. It would be nice if the - # federation api gave us a way of finding out which we actually need. - - missing_auth_events = set(auth_event_ids) - fetched_events.keys() - missing_auth_events.difference_update( - await self.store.have_seen_events(room_id, missing_auth_events) - ) - logger.debug("We are also missing %i auth events", len(missing_auth_events)) - - missing_events = missing_desired_events | missing_auth_events - logger.debug("Fetching %i events from remote", len(missing_events)) - await self._get_events_and_persist( - destination=destination, room_id=room_id, events=missing_events - ) - - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self.store.get_events(missing_desired_events, allow_rejected=True) - ) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] - - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) - - del fetched_events[bad_event_id] - - # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - - # missing state at that event is a warning, not a blocker - # XXX: this doesn't sound right? it means that we'll end up with incomplete - # state. - failed_to_fetch = desired_events - fetched_events.keys() - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state events for %s %s", - event_id, - failed_to_fetch, - ) - - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - - return remote_state - - async def _process_received_pdu( - self, - origin: str, - event: EventBase, - state: Optional[Iterable[EventBase]], - backfilled: bool = False, - ) -> None: - """Called when we have a new pdu. We need to do auth checks and put it - through the StateHandler. - - Args: - origin: server sending the event - - event: event to be persisted - - state: Normally None, but if we are handling a gap in the graph - (ie, we are missing one or more prev_events), the resolved state at the - event - - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - logger.debug("Processing event: %s", event) - - try: - context = await self.state_handler.compute_event_context( - event, old_state=state - ) - await self._auth_and_persist_event( - origin, event, context, state=state, backfilled=backfilled - ) - except AuthError as e: - raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - - if backfilled: - return - - # For encrypted messages we check that we know about the sending device, - # if we don't then we mark the device cache for that user as stale. - if event.type == EventTypes.Encrypted: - device_id = event.content.get("device_id") - sender_key = event.content.get("sender_key") - - cached_devices = await self.store.get_cached_devices_for_user(event.sender) - - resync = False # Whether we should resync device lists. - - device = None - if device_id is not None: - device = cached_devices.get(device_id) - if device is None: - logger.info( - "Received event from remote device not in our cache: %s %s", - event.sender, - device_id, - ) - resync = True - - # We also check if the `sender_key` matches what we expect. - if sender_key is not None: - # Figure out what sender key we're expecting. If we know the - # device and recognize the algorithm then we can work out the - # exact key to expect. Otherwise check it matches any key we - # have for that device. - - current_keys: Container[str] = [] - - if device: - keys = device.get("keys", {}).get("keys", {}) - - if ( - event.content.get("algorithm") - == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 - ): - # For this algorithm we expect a curve25519 key. - key_name = "curve25519:%s" % (device_id,) - current_keys = [keys.get(key_name)] - else: - # We don't know understand the algorithm, so we just - # check it matches a key for the device. - current_keys = keys.values() - elif device_id: - # We don't have any keys for the device ID. - pass - else: - # The event didn't include a device ID, so we just look for - # keys across all devices. - current_keys = [ - key - for device in cached_devices.values() - for key in device.get("keys", {}).get("keys", {}).values() - ] - - # We now check that the sender key matches (one of) the expected - # keys. - if sender_key not in current_keys: - logger.info( - "Received event from remote device with unexpected sender key: %s %s: %s", - event.sender, - device_id or "", - sender_key, - ) - resync = True - - if resync: - run_as_background_process( - "resync_device_due_to_pdu", self._resync_device, event.sender - ) - - await self._handle_marker_event(origin, event) - - async def _handle_marker_event(self, origin: str, marker_event: EventBase): - """Handles backfilling the insertion event when we receive a marker - event that points to one. - - Args: - origin: Origin of the event. Will be called to get the insertion event - marker_event: The event to process - """ - - if marker_event.type != EventTypes.MSC2716_MARKER: - # Not a marker event - return - - if marker_event.rejected_reason is not None: - # Rejected event - return - - # Skip processing a marker event if the room version doesn't - # support it. - room_version = await self.store.get_room_version(marker_event.room_id) - if not room_version.msc2716_historical: - return - - logger.debug("_handle_marker_event: received %s", marker_event) - - insertion_event_id = marker_event.content.get( - EventContentFields.MSC2716_MARKER_INSERTION - ) - - if insertion_event_id is None: - # Nothing to retrieve then (invalid marker) - return - - logger.debug( - "_handle_marker_event: backfilling insertion event %s", insertion_event_id - ) - - await self._get_events_and_persist( - origin, - marker_event.room_id, - [insertion_event_id], - ) - - insertion_event = await self.store.get_event( - insertion_event_id, allow_none=True - ) - if insertion_event is None: - logger.warning( - "_handle_marker_event: server %s didn't return insertion event %s for marker %s", - origin, - insertion_event_id, - marker_event.event_id, - ) - return - - logger.debug( - "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", - insertion_event, - marker_event, - ) - - await self.store.insert_insertion_extremity( - insertion_event_id, marker_event.room_id - ) - - logger.debug( - "_handle_marker_event: insertion extremity added for %s from marker event %s", - insertion_event, - marker_event, - ) - - async def _resync_device(self, sender: str) -> None: - """We have detected that the device list for the given user may be out - of sync, so we try and resync them. - """ - - try: - await self.store.mark_remote_user_device_cache_as_stale(sender) - - # Immediately attempt a resync in the background - if self.config.worker_app: - await self._user_device_resync(user_id=sender) - else: - await self._device_list_updater.user_device_resync(sender) - except Exception: - logger.exception("Failed to resync device for %s", sender) - - @log_function - async def backfill( - self, dest: str, room_id: str, limit: int, extremities: List[str] - ) -> None: - """Trigger a backfill request to `dest` for the given `room_id` - - This will attempt to get more events from the remote. If the other side - has no new events to offer, this will return an empty list. - - As the events are received, we check their signatures, and also do some - sanity-checking on them. If any of the backfilled events are invalid, - this method throws a SynapseError. - - We might also raise an InvalidResponseError if the response from the remote - server is just bogus. - - TODO: make this more useful to distinguish failures of the remote - server from invalid events (there is probably no point in trying to - re-fetch invalid events from every other HS in the room.) - """ - if dest == self.server_name: - raise SynapseError(400, "Can't backfill from self.") - - events = await self.federation_client.backfill( - dest, room_id, limit=limit, extremities=extremities - ) - - if not events: - return - - # if there are any events in the wrong room, the remote server is buggy and - # should not be trusted. - for ev in events: - if ev.room_id != room_id: - raise InvalidResponseError( - f"Remote server {dest} returned event {ev.event_id} which is in " - f"room {ev.room_id}, when we were backfilling in {room_id}" - ) - - await self._process_pulled_events(dest, events, backfilled=True) - async def maybe_backfill( self, room_id: str, current_depth: int, limit: int ) -> bool: @@ -995,7 +308,7 @@ async def try_backfill(domains: List[str]) -> bool: # TODO: Should we try multiple of these at a time? for dom in domains: try: - await self.backfill( + await self._federation_event_handler.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -1084,316 +397,7 @@ async def try_backfill(domains: List[str]) -> bool: tried_domains.update(dom for dom, _ in likely_extremeties_domains) - return False - - async def _get_events_and_persist( - self, destination: str, room_id: str, events: Iterable[str] - ) -> None: - """Fetch the given events from a server, and persist them as outliers. - - This function *does not* recursively get missing auth events of the - newly fetched events. Callers must include in the `events` argument - any missing events from the auth chain. - - Logs a warning if we can't find the given event. - """ - - room_version = await self.store.get_room_version(room_id) - - event_map: Dict[str, EventBase] = {} - - async def get_event(event_id: str): - with nested_logging_context(event_id): - try: - event = await self.federation_client.get_pdu( - [destination], - event_id, - room_version, - outlier=True, - ) - if event is None: - logger.warning( - "Server %s didn't return event %s", - destination, - event_id, - ) - return - - event_map[event.event_id] = event - - except Exception as e: - logger.warning( - "Error fetching missing state/auth event %s: %s %s", - event_id, - type(e), - e, - ) - - await concurrently_execute(get_event, events, 5) - - # Make a map of auth events for each event. We do this after fetching - # all the events as some of the events' auth events will be in the list - # of requested events. - - auth_events = [ - aid - for event in event_map.values() - for aid in event.auth_event_ids() - if aid not in event_map - ] - persisted_events = await self.store.get_events( - auth_events, - allow_rejected=True, - ) - - event_infos = [] - for event in event_map.values(): - auth = {} - for auth_event_id in event.auth_event_ids(): - ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) - if ae: - auth[(ae.type, ae.state_key)] = ae - else: - logger.info("Missing auth event %s", auth_event_id) - - event_infos.append(_NewEventInfo(event, auth)) - - if event_infos: - await self._auth_and_persist_events( - destination, - room_id, - event_infos, - ) - - async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase], backfilled: bool - ) -> None: - """Process a batch of events we have pulled from a remote server - - Pulls in any events required to auth the events, persists the received events, - and notifies clients, if appropriate. - - Assumes the events have already had their signatures and hashes checked. - - Params: - origin: The server we received these events from - events: The received events. - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - - # We want to sort these by depth so we process them and - # tell clients about them in order. - sorted_events = sorted(events, key=lambda x: x.depth) - - for ev in sorted_events: - with nested_logging_context(ev.event_id): - await self._process_pulled_event(origin, ev, backfilled=backfilled) - - async def _process_pulled_event( - self, origin: str, event: EventBase, backfilled: bool - ) -> None: - """Process a single event that we have pulled from a remote server - - Pulls in any events required to auth the event, persists the received event, - and notifies clients, if appropriate. - - Assumes the event has already had its signatures and hashes checked. - - This is somewhat equivalent to on_receive_pdu, but applies somewhat different - logic in the case that we are missing prev_events (in particular, it just - requests the state at that point, rather than triggering a get_missing_events) - - so is appropriate when we have pulled the event from a remote server, rather - than having it pushed to us. - - Params: - origin: The server we received this event from - events: The received event - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - logger.info("Processing pulled event %s", event) - - # these should not be outliers. - assert not event.internal_metadata.is_outlier() - - event_id = event.event_id - - existing = await self.store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "Ignoring received event %s which we have already seen", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - - try: - self._sanity_check_event(event) - except SynapseError as err: - logger.warning("Event %s failed sanity check: %s", event_id, err) - return - - try: - state = await self._resolve_state_at_missing_prevs(origin, event) - await self._process_received_pdu( - origin, event, state=state, backfilled=backfilled - ) - except FederationError as e: - if e.code == 403: - logger.warning("Pulled event %s failed history check.", event_id) - else: - raise - - async def _resolve_state_at_missing_prevs( - self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: - """Calculate the state at an event with missing prev_events. - - This is used when we have pulled a batch of events from a remote server, and - still don't have all the prev_events. - - If we already have all the prev_events for `event`, this method does nothing. - - Otherwise, the missing prevs become new backwards extremities, and we fall back - to asking the remote server for the state after each missing `prev_event`, - and resolving across them. - - That's ok provided we then resolve the state against other bits of the DAG - before using it - in other words, that the received event `event` is not going - to become the only forwards_extremity in the room (which will ensure that you - can't just take over a room by sending an event, withholding its prev_events, - and declaring yourself to be an admin in the subsequent state request). - - In other words: we should only call this method if `event` has been *pulled* - as part of a batch of missing prev events, or similar. - - Params: - dest: the remote server to ask for state at the missing prevs. Typically, - this will be the server we got `event` from. - event: an event to check for missing prevs. - - Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. - """ - room_id = event.room_id - event_id = event.event_id - - prevs = set(event.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - return None - - logger.info( - "Event %s is missing prev_events %s: calculating state for a " - "backwards extremity", - event_id, - shortstr(missing_prevs), - ) - # Calculate the state after each of the previous events, and - # resolve them to find the correct state at the current event. - event_map = {event_id: event} - try: - # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) - - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours - - # Ask the remote server for the states we don't - # know about - for p in missing_prevs: - logger.info("Requesting state after missing prev_event %s", p) - - with nested_logging_context(p): - # note that if any of the missing prevs share missing state or - # auth events, the requests to fetch those events are deduped - # by the get_pdu_cache in federation_client. - remote_state = await self._get_state_after_missing_prev_event( - dest, room_id, p - ) - - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x - - room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), - ) - - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] - except Exception: - logger.warning( - "Error attempting to resolve state at missing prev_events", - exc_info=True, - ) - raise FederationError( - "ERROR", - 403, - "We can't get valid state history.", - affected=event_id, - ) - return state - - def _sanity_check_event(self, ev: EventBase) -> None: - """ - Do some early sanity checks of a received event - - In particular, checks it doesn't have an excessive number of - prev_events or auth_events, which could cause a huge state resolution - or cascade of event fetches. - - Args: - ev: event to be checked - - Raises: - SynapseError if the event does not pass muster - """ - if len(ev.prev_event_ids()) > 20: - logger.warning( - "Rejecting event %s which has %i prev_events", - ev.event_id, - len(ev.prev_event_ids()), - ) - raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") - - if len(ev.auth_event_ids()) > 10: - logger.warning( - "Rejecting event %s which has %i auth_events", - ev.event_id, - len(ev.auth_event_ids()), - ) - raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + return False async def send_invite(self, target_host: str, event: EventBase) -> EventBase: """Sends the invite to the remote server for signing. @@ -1460,9 +464,9 @@ async def do_invite_join( # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. - assert room_id not in self.room_queues + assert room_id not in self._federation_event_handler.room_queues - self.room_queues[room_id] = [] + self._federation_event_handler.room_queues[room_id] = [] await self._clean_room_for_join(room_id) @@ -1536,8 +540,8 @@ async def do_invite_join( logger.debug("Finished joining %s to %s", joinee, room_id) return event.event_id, max_stream_id finally: - room_queue = self.room_queues[room_id] - del self.room_queues[room_id] + room_queue = self._federation_event_handler.room_queues[room_id] + del self._federation_event_handler.room_queues[room_id] # we don't need to wait for the queued events to be processed - # it's just a best-effort thing at this point. We do want to do @@ -1613,7 +617,7 @@ async def do_knock( event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( + stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) return event.event_id, stream_id @@ -1633,7 +637,7 @@ async def _handle_queued_pdus( p, ) with nested_logging_context(p.event_id): - await self.on_receive_pdu(origin, p) + await self._federation_event_handler.on_receive_pdu(origin, p) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e @@ -1726,7 +730,7 @@ async def on_make_join_request( raise # Ensure the user can even join the room. - await self._check_join_restrictions(context, event) + await self._federation_event_handler.check_join_restrictions(context, event) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1803,7 +807,9 @@ async def on_invite_request( ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify(event.room_id, [(event, context)]) + await self._federation_event_handler.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event @@ -1830,7 +836,7 @@ async def do_remotely_reject_invite( await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( + stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1973,116 +979,6 @@ async def on_make_knock_request( return event - @log_function - async def on_send_membership_event( - self, origin: str, event: EventBase - ) -> Tuple[EventBase, EventContext]: - """ - We have received a join/leave/knock event for a room via send_join/leave/knock. - - Verify that event and send it into the room on the remote homeserver's behalf. - - This is quite similar to on_receive_pdu, with the following principal - differences: - * only membership events are permitted (and only events with - sender==state_key -- ie, no kicks or bans) - * *We* send out the event on behalf of the remote server. - * We enforce the membership restrictions of restricted rooms. - * Rejected events result in an exception rather than being stored. - - There are also other differences, however it is not clear if these are by - design or omission. In particular, we do not attempt to backfill any missing - prev_events. - - Args: - origin: The homeserver of the remote (joining/invited/knocking) user. - event: The member event that has been signed by the remote homeserver. - - Returns: - The event and context of the event after inserting it into the room graph. - - Raises: - SynapseError if the event is not accepted into the room - """ - logger.debug( - "on_send_membership_event: Got event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - if get_domain_from_id(event.sender) != origin: - logger.info( - "Got send_membership request for user %r from different origin %s", - event.sender, - origin, - ) - raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - - if event.sender != event.state_key: - raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) - - assert not event.internal_metadata.outlier - - # Send this event on behalf of the other server. - # - # The remote server isn't a full participant in the room at this point, so - # may not have an up-to-date list of the other homeservers participating in - # the room, so we send it on their behalf. - event.internal_metadata.send_on_behalf_of = origin - - context = await self.state_handler.compute_event_context(event) - context = await self._check_event_auth(origin, event, context) - if context.rejected: - raise SynapseError( - 403, f"{event.membership} event was rejected", Codes.FORBIDDEN - ) - - # for joins, we need to check the restrictions of restricted rooms - if event.membership == Membership.JOIN: - await self._check_join_restrictions(context, event) - - # for knock events, we run the third-party event rules. It's not entirely clear - # why we don't do this for other sorts of membership events. - if event.membership == Membership.KNOCK: - event_allowed, _ = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of knock %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - - # all looks good, we can persist the event. - await self._run_push_actions_and_persist_event(event, context) - return event, context - - async def _check_join_restrictions( - self, context: EventContext, event: EventBase - ) -> None: - """Check that restrictions in restricted join rules are matched - - Called when we receive a join event via send_join. - - Raises an auth error if the restrictions are not matched. - """ - prev_state_ids = await context.get_prev_state_ids() - - # Check if the user is already in the room or invited to the room. - user_id = event.state_key - prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - - # Check if the member should be allowed access via membership in a space. - await self._event_auth_handler.check_restricted_join_rules( - prev_state_ids, - event.room_version, - user_id, - prev_member_event, - ) - async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event.""" @@ -2183,126 +1079,6 @@ async def get_persisted_pdu( else: return None - async def get_min_depth_for_context(self, context: str) -> int: - return await self.store.get_min_depth(context) - - async def _auth_and_persist_event( - self, - origin: str, - event: EventBase, - context: EventContext, - state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, - backfilled: bool = False, - ) -> None: - """ - Process an event by performing auth checks and then persisting to the database. - - Args: - origin: The host the event originates from. - event: The event itself. - context: - The event context. - - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly incomplete, and possibly including events that are not yet - persisted, or authed, or in the right room. - - Only populated where we may not already have persisted these events - - for example, when populating outliers. - - backfilled: True if the event was backfilled. - """ - context = await self._check_event_auth( - origin, - event, - context, - state=state, - claimed_auth_event_map=claimed_auth_event_map, - backfilled=backfilled, - ) - - await self._run_push_actions_and_persist_event(event, context, backfilled) - - async def _run_push_actions_and_persist_event( - self, event: EventBase, context: EventContext, backfilled: bool = False - ): - """Run the push actions for a received event, and persist it. - - Args: - event: The event itself. - context: The event context. - backfilled: True if the event was backfilled. - """ - try: - if ( - not event.internal_metadata.is_outlier() - and not backfilled - and not context.rejected - and (await self.store.get_min_depth(event.room_id)) <= event.depth - ): - await self.action_generator.handle_push_actions_for_event( - event, context - ) - - await self.persist_events_and_notify( - event.room_id, [(event, context)], backfilled=backfilled - ) - except Exception: - run_in_background( - self.store.remove_push_actions_from_staging, event.event_id - ) - raise - - async def _auth_and_persist_events( - self, - origin: str, - room_id: str, - event_infos: Collection[_NewEventInfo], - ) -> None: - """Creates the appropriate contexts and persists events. The events - should not depend on one another, e.g. this should be used to persist - a bunch of outliers, but not a chunk of individual events that depend - on each other for state calculations. - - Notifies about the events where appropriate. - """ - - if not event_infos: - return - - async def prep(ev_info: _NewEventInfo): - event = ev_info.event - with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context(event) - res = await self._check_event_auth( - origin, - event, - res, - claimed_auth_event_map=ev_info.claimed_auth_event_map, - ) - return res - - contexts = await make_deferred_yieldable( - defer.gatherResults( - [run_in_background(prep, ev_info) for ev_info in event_infos], - consumeErrors=True, - ) - ) - - await self.persist_events_and_notify( - room_id, - [ - (ev_info.event, context) - for ev_info, context in zip(event_infos, contexts) - ], - ) - async def _persist_auth_tree( self, origin: str, @@ -2400,7 +1176,7 @@ async def _persist_auth_tree( events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR if auth_events or state: - await self.persist_events_and_notify( + await self._federation_event_handler.persist_events_and_notify( room_id, [ (e, events_to_context[e.event_id]) @@ -2412,108 +1188,10 @@ async def _persist_auth_tree( event, old_state=state ) - return await self.persist_events_and_notify( + return await self._federation_event_handler.persist_events_and_notify( room_id, [(event, new_event_context)] ) - async def _check_for_soft_fail( - self, - event: EventBase, - state: Optional[Iterable[EventBase]], - backfilled: bool, - origin: str, - ) -> None: - """Checks if we should soft fail the event; if so, marks the event as - such. - - Args: - event - state: The state at the event if we don't have all the event's prev events - backfilled: Whether the event is from backfill - origin: The host the event originates from. - """ - # For new (non-backfilled and non-outlier) events we check if the event - # passes auth based on the current state. If it doesn't then we - # "soft-fail" the event. - if backfilled or event.internal_metadata.is_outlier(): - return - - extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids_list) - prev_event_ids = set(event.prev_event_ids()) - - if extrem_ids == prev_event_ids: - # If they're the same then the current state is the same as the - # state at the event, so no point rechecking auth for soft fail. - return - - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # Calculate the "current state". - if state is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, - # however we still want to do checks as gaps are easy to - # maliciously manufacture. - # - # So we use a "current state" that is actually a state - # resolution across the current forward extremities and the - # given state at the event. This should correctly handle cases - # like bans, especially with state res v2. - - state_sets_d = await self.state_store.get_state_groups( - event.room_id, extrem_ids - ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) - state_sets.append(state) - current_states = await self.state_handler.resolve_events( - room_version, state_sets, event - ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } - else: - current_state_ids = await self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids - ) - - logger.debug( - "Doing soft-fail check for %s: state %s", - event.event_id, - current_state_ids, - ) - - # Now check if event pass auth against said current state - auth_types = auth_types_for_event(room_version_obj, event) - current_state_ids_list = [ - e for k, e in current_state_ids.items() if k in auth_types - ] - - auth_events_map = await self.store.get_events(current_state_ids_list) - current_auth_events = { - (e.type, e.state_key): e for e in auth_events_map.values() - } - - try: - event_auth.check(room_version_obj, event, auth_events=current_auth_events) - except AuthError as e: - logger.warning( - "Soft-failing %r (from %s) because %s", - event, - e, - origin, - extra={ - "room_id": event.room_id, - "mxid": event.sender, - "hs": origin, - }, - ) - soft_failed_event_counter.inc() - event.internal_metadata.soft_failed = True - async def on_get_missing_events( self, origin: str, @@ -2542,334 +1220,6 @@ async def on_get_missing_events( return missing_events - async def _check_event_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, - backfilled: bool = False, - ) -> EventContext: - """ - Checks whether an event should be rejected (for failing auth checks). - - Args: - origin: The host the event originates from. - event: The event itself. - context: - The event context. - - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly incomplete, and possibly including events that are not yet - persisted, or authed, or in the right room. - - Only populated where we may not already have persisted these events - - for example, when populating outliers, or the state for a backwards - extremity. - - backfilled: True if the event was backfilled. - - Returns: - The updated context object. - """ - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - if claimed_auth_event_map: - # if we have a copy of the auth events from the event, use that as the - # basis for auth. - auth_events = claimed_auth_event_map - else: - # otherwise, we calculate what the auth events *should* be, and use that - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - - try: - ( - context, - auth_events_for_auth, - ) = await self._update_auth_events_and_context_for_auth( - origin, event, context, auth_events - ) - except Exception: - # We don't really mind if the above fails, so lets not fail - # processing if it does. However, it really shouldn't fail so - # let's still log as an exception since we'll still want to fix - # any bugs. - logger.exception( - "Failed to double check auth events for %s with remote. " - "Ignoring failure and continuing processing of event.", - event.event_id, - ) - auth_events_for_auth = auth_events - - try: - event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) - except AuthError as e: - logger.warning("Failed auth resolution for %r because %s", event, e) - context.rejected = RejectedReason.AUTH_ERROR - - if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled, origin=origin) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) - - # If we are going to send this event over federation we precaclculate - # the joined hosts. - if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event( - event, context - ) - - return context - - async def _update_auth_events_and_context_for_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - input_auth_events: StateMap[EventBase], - ) -> Tuple[EventContext, StateMap[EventBase]]: - """Helper for _check_event_auth. See there for docs. - - Checks whether a given event has the expected auth events. If it - doesn't then we talk to the remote server to compare state to see if - we can come to a consensus (e.g. if one server missed some valid - state). - - This attempts to resolve any potential divergence of state between - servers, but is not essential and so failures should not block further - processing of the event. - - Args: - origin: - event: - context: - - input_auth_events: - Map from (event_type, state_key) to event - - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. - - Returns: - updated context, updated auth event map - """ - # take a copy of input_auth_events before we modify it. - auth_events: MutableStateMap[EventBase] = dict(input_auth_events) - - event_auth_events = set(event.auth_event_ids()) - - # missing_auth is the set of the event's auth_events which we don't yet have - # in auth_events. - missing_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - # if we have missing events, we need to fetch those events from somewhere. - # - # we start by checking if they are in the store, and then try calling /event_auth/. - if missing_auth: - have_events = await self.store.have_seen_events(event.room_id, missing_auth) - logger.debug("Events %s are in the store", have_events) - missing_auth.difference_update(have_events) - - if missing_auth: - # If we don't have all the auth events, we need to get them. - logger.info("auth_events contains unknown events: %s", missing_auth) - try: - try: - remote_auth_chain = await self.federation_client.get_event_auth( - origin, event.room_id, event.event_id - ) - except RequestSendFailed as e1: - # The other side isn't around or doesn't implement the - # endpoint, so lets just bail out. - logger.info("Failed to get event auth from remote: %s", e1) - return context, auth_events - - seen_remotes = await self.store.have_seen_events( - event.room_id, [e.event_id for e in remote_auth_chain] - ) - - for e in remote_auth_chain: - if e.event_id in seen_remotes: - continue - - if e.event_id == event.event_id: - continue - - try: - auth_ids = e.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in remote_auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - e.internal_metadata.outlier = True - - logger.debug( - "_check_event_auth %s missing_auth: %s", - event.event_id, - e.event_id, - ) - missing_auth_event_context = ( - await self.state_handler.compute_event_context(e) - ) - await self._auth_and_persist_event( - origin, - e, - missing_auth_event_context, - claimed_auth_event_map=auth, - ) - - if e.event_id in event_auth_events: - auth_events[(e.type, e.state_key)] = e - except AuthError: - pass - - except Exception: - logger.exception("Failed to get auth chain") - - if event.internal_metadata.is_outlier(): - # XXX: given that, for an outlier, we'll be working with the - # event's *claimed* auth events rather than those we calculated: - # (a) is there any point in this test, since different_auth below will - # obviously be empty - # (b) alternatively, why don't we do it earlier? - logger.info("Skipping auth_event fetch for outlier") - return context, auth_events - - different_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - if not different_auth: - return context, auth_events - - logger.info( - "auth_events refers to events which are not in our calculated auth " - "chain: %s", - different_auth, - ) - - # XXX: currently this checks for redactions but I'm not convinced that is - # necessary? - different_events = await self.store.get_events_as_list(different_auth) - - for d in different_events: - if d.room_id != event.room_id: - logger.warning( - "Event %s refers to auth_event %s which is in a different room", - event.event_id, - d.event_id, - ) - - # don't attempt to resolve the claimed auth events against our own - # in this case: just use our own auth events. - # - # XXX: should we reject the event in this case? It feels like we should, - # but then shouldn't we also do so if we've failed to fetch any of the - # auth events? - return context, auth_events - - # now we state-resolve between our own idea of the auth events, and the remote's - # idea of them. - - local_state = auth_events.values() - remote_auth_events = dict(auth_events) - remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) - remote_state = remote_auth_events.values() - - room_version = await self.store.get_room_version_id(event.room_id) - new_state = await self.state_handler.resolve_events( - room_version, (local_state, remote_state), event - ) - - logger.info( - "After state res: updating auth_events with new state %s", - { - (d.type, d.state_key): d.event_id - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, - ) - - auth_events.update(new_state) - - context = await self._update_context_for_auth_events( - event, context, auth_events - ) - - return context, auth_events - - async def _update_context_for_auth_events( - self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] - ) -> EventContext: - """Update the state_ids in an event context after auth event resolution, - storing the changes as a new state group. - - Args: - event: The event we're handling the context for - - context: initial event context - - auth_events: Events to update in the event context. - - Returns: - new event context - """ - # exclude the state key of the new event from the current_state in the context. - if event.is_state(): - event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) - else: - event_key = None - state_updates = { - k: a.event_id for k, a in auth_events.items() if k != event_key - } - - current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) # type: ignore - - current_state_ids.update(state_updates) - - prev_state_ids = await context.get_prev_state_ids() - prev_state_ids = dict(prev_state_ids) - - prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) - - # create a new state group as a delta from the existing one. - prev_group = context.state_group - state_group = await self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=state_updates, - current_state_ids=current_state_ids, - ) - - return EventContext.with_state( - state_group=state_group, - state_group_before_event=context.state_group_before_event, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=prev_group, - delta_ids=state_updates, - ) - async def construct_auth_difference( self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] ) -> Dict: @@ -3256,99 +1606,6 @@ async def _check_key_revocation(self, public_key: str, url: str) -> None: if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") - async def persist_events_and_notify( - self, - room_id: str, - event_and_contexts: Sequence[Tuple[EventBase, EventContext]], - backfilled: bool = False, - ) -> int: - """Persists events and tells the notifier/pushers about them, if - necessary. - - Args: - room_id: The room ID of events being persisted. - event_and_contexts: Sequence of events with their associated - context that should be persisted. All events must belong to - the same room. - backfilled: Whether these events are a result of - backfilling or not - - Returns: - The stream ID after which all events have been persisted. - """ - if not event_and_contexts: - return self.store.get_current_events_token() - - instance = self.config.worker.events_shard_config.get_instance(room_id) - if instance != self._instance_name: - # Limit the number of events sent over replication. We choose 200 - # here as that is what we default to in `max_request_body_size(..)` - for batch in batch_iter(event_and_contexts, 200): - result = await self._send_events( - instance_name=instance, - store=self.store, - room_id=room_id, - event_and_contexts=batch, - backfilled=backfilled, - ) - return result["max_stream_id"] - else: - assert self.storage.persistence - - # Note that this returns the events that were persisted, which may not be - # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self.storage.persistence.persist_events( - event_and_contexts, backfilled=backfilled - ) - - if self._ephemeral_messages_enabled: - for event in events: - # If there's an expiry timestamp on the event, schedule its expiry. - self._message_handler.maybe_schedule_expiry(event) - - if not backfilled: # Never notify for backfilled events - for event in events: - await self._notify_persisted_event(event, max_stream_token) - - return max_stream_token.stream - - async def _notify_persisted_event( - self, event: EventBase, max_stream_token: RoomStreamToken - ) -> None: - """Checks to see if notifier/pushers should be notified about the - event or not. - - Args: - event: - max_stream_id: The max_stream_id returned by persist_events - """ - - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - - # We notify for memberships if its an invite for one of our - # users - if event.internal_metadata.is_outlier(): - if event.membership != Membership.INVITE: - if not self.is_mine_id(target_user_id): - return - - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - elif event.internal_metadata.is_outlier(): - return - - # the event has been persisted so it should have a stream ordering. - assert event.internal_metadata.stream_ordering - - event_pos = PersistedEventPosition( - self._instance_name, event.internal_metadata.stream_ordering - ) - self.notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users - ) - async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the server to join the room. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py new file mode 100644 index 000000000000..9f055f00cff3 --- /dev/null +++ b/synapse/handlers/federation_event.py @@ -0,0 +1,1825 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http import HTTPStatus +from typing import ( + TYPE_CHECKING, + Collection, + Container, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) + +import attr +from prometheus_client import Counter + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RejectedReason, + RoomEncryptionAlgorithms, +) +from synapse.api.errors import ( + AuthError, + Codes, + FederationError, + HttpResponseException, + RequestSendFailed, + SynapseError, +) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers._base import BaseHandler +from synapse.logging.context import ( + make_deferred_yieldable, + nested_logging_context, + run_in_background, +) +from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet +from synapse.replication.http.federation import ( + ReplicationFederationSendEventsRestServlet, +) +from synapse.state import StateResolutionStore +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import ( + MutableStateMap, + PersistedEventPosition, + RoomStreamToken, + StateMap, + UserID, + get_domain_from_id, +) +from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.iterutils import batch_iter +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + +soft_failed_event_counter = Counter( + "synapse_federation_soft_failed_events_total", + "Events received over federation that we marked as soft_failed", +) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _NewEventInfo: + """Holds information about a received event, ready for passing to _auth_and_persist_events + + Attributes: + event: the received event + + claimed_auth_event_map: a map of (type, state_key) => event for the event's + claimed auth_events. + + This can include events which have not yet been persisted, in the case that + we are backfilling a batch of events. + + Note: May be incomplete: if we were unable to find all of the claimed auth + events. Also, treat the contents with caution: the events might also have + been rejected, might not yet have been authorized themselves, or they might + be in the wrong room. + + """ + + event: EventBase + claimed_auth_event_map: StateMap[EventBase] + + +class FederationEventHandler(BaseHandler): + """Handles events that originated from federation. + + Responsible for handing incoming events and passing them on to the rest + of the homeserver (including auth and state conflict resolutions) + """ + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state + + self.state_handler = hs.get_state_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self._event_auth_handler = hs.get_event_auth_handler() + self._message_handler = hs.get_message_handler() + self.action_generator = hs.get_action_generator() + self._state_resolution_handler = hs.get_state_resolution_handler() + + self.federation_client = hs.get_federation_client() + self.third_party_event_rules = hs.get_third_party_event_rules() + + self.is_mine_id = hs.is_mine_id + self._instance_name = hs.get_instance_name() + + self.config = hs.config + self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages + + self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) + if hs.config.worker_app: + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + else: + self._device_list_updater = hs.get_device_handler().device_list_updater + + # When joining a room we need to queue any events for that room up. + # For each room, a list of (pdu, origin) tuples. + # TODO: replace this with something more elegant, probably based around the + # federation event staging area. + self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} + + self._room_pdu_linearizer = Linearizer("fed_room_pdu") + + async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: + """Process a PDU received via a federation /send/ transaction + + Args: + origin: server which initiated the /send/ transaction. Will + be used to fetch missing events or state. + pdu: received PDU + """ + + room_id = pdu.room_id + event_id = pdu.event_id + + # We reprocess pdus when we have seen them only as outliers + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + + # FIXME: Currently we fetch an event again when we already have it + # if it has been marked as an outlier. + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", event_id + ) + return + if pdu.internal_metadata.is_outlier(): + logger.info( + "Ignoring received outlier %s which we already have as an outlier", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + # do some initial sanity-checking of the event. In particular, make + # sure it doesn't have hundreds of prev_events or auth_events, which + # could cause a huge state resolution or cascade of event fetches. + try: + self._sanity_check_event(pdu) + except SynapseError as err: + logger.warning("Received event failed sanity checks") + raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) + + # If we are currently in the process of joining this room, then we + # queue up events for later processing. + if room_id in self.room_queues: + logger.info( + "Queuing PDU from %s for now: join in progress", + origin, + ) + self.room_queues[room_id].append((pdu, origin)) + return + + # If we're not in the room just ditch the event entirely. This is + # probably an old server that has come back and thinks we're still in + # the room (or we've been rejoined to the room by a state reset). + # + # Note that if we were never in the room then we would have already + # dropped the event, since we wouldn't know the room version. + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) + if not is_in_room: + logger.info( + "Ignoring PDU from %s as we're not in the room", + origin, + ) + return None + + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). + # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph + if not pdu.internal_metadata.is_outlier(): + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if missing_prevs: + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) + + if min_depth is not None and pdu.depth > min_depth: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + logger.info( + "Acquiring room lock to fetch %d missing prev_events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + with (await self._room_pdu_linearizer.queue(pdu.room_id)): + logger.info( + "Acquired room lock to fetch %d missing prev_events", + len(missing_prevs), + ) + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) from e + + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + logger.info("Found all missing prev_events") + + if missing_prevs: + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. + logger.warning( + "Rejecting: failed to fetch %d prev events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) + + await self._process_received_pdu(origin, pdu, state=None) + + @log_function + async def on_send_membership_event( + self, origin: str, event: EventBase + ) -> Tuple[EventBase, EventContext]: + """ + We have received a join/leave/knock event for a room via send_join/leave/knock. + + Verify that event and send it into the room on the remote homeserver's behalf. + + This is quite similar to on_receive_pdu, with the following principal + differences: + * only membership events are permitted (and only events with + sender==state_key -- ie, no kicks or bans) + * *We* send out the event on behalf of the remote server. + * We enforce the membership restrictions of restricted rooms. + * Rejected events result in an exception rather than being stored. + + There are also other differences, however it is not clear if these are by + design or omission. In particular, we do not attempt to backfill any missing + prev_events. + + Args: + origin: The homeserver of the remote (joining/invited/knocking) user. + event: The member event that has been signed by the remote homeserver. + + Returns: + The event and context of the event after inserting it into the room graph. + + Raises: + SynapseError if the event is not accepted into the room + """ + logger.debug( + "on_send_membership_event: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got send_membership request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + if event.sender != event.state_key: + raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) + + assert not event.internal_metadata.outlier + + # Send this event on behalf of the other server. + # + # The remote server isn't a full participant in the room at this point, so + # may not have an up-to-date list of the other homeservers participating in + # the room, so we send it on their behalf. + event.internal_metadata.send_on_behalf_of = origin + + context = await self.state_handler.compute_event_context(event) + context = await self._check_event_auth(origin, event, context) + if context.rejected: + raise SynapseError( + 403, f"{event.membership} event was rejected", Codes.FORBIDDEN + ) + + # for joins, we need to check the restrictions of restricted rooms + if event.membership == Membership.JOIN: + await self.check_join_restrictions(context, event) + + # for knock events, we run the third-party event rules. It's not entirely clear + # why we don't do this for other sorts of membership events. + if event.membership == Membership.KNOCK: + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + # all looks good, we can persist the event. + await self._run_push_actions_and_persist_event(event, context) + return event, context + + async def check_join_restrictions( + self, context: EventContext, event: EventBase + ) -> None: + """Check that restrictions in restricted join rules are matched + + Called when we receive a join event via send_join. + + Raises an auth error if the restrictions are not matched. + """ + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + + # Check if the member should be allowed access via membership in a space. + await self._event_auth_handler.check_restricted_join_rules( + prev_state_ids, + event.room_version, + user_id, + prev_member_event, + ) + + @log_function + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: List[str] + ) -> None: + """Trigger a backfill request to `dest` for the given `room_id` + + This will attempt to get more events from the remote. If the other side + has no new events to offer, this will return an empty list. + + As the events are received, we check their signatures, and also do some + sanity-checking on them. If any of the backfilled events are invalid, + this method throws a SynapseError. + + We might also raise an InvalidResponseError if the response from the remote + server is just bogus. + + TODO: make this more useful to distinguish failures of the remote + server from invalid events (there is probably no point in trying to + re-fetch invalid events from every other HS in the room.) + """ + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") + + events = await self.federation_client.backfill( + dest, room_id, limit=limit, extremities=extremities + ) + + if not events: + return + + # if there are any events in the wrong room, the remote server is buggy and + # should not be trusted. + for ev in events: + if ev.room_id != room_id: + raise InvalidResponseError( + f"Remote server {dest} returned event {ev.event_id} which is in " + f"room {ev.room_id}, when we were backfilling in {room_id}" + ) + + await self._process_pulled_events(dest, events, backfilled=True) + + async def _get_missing_events_for_pdu( + self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int + ) -> None: + """ + Args: + origin: Origin of the pdu. Will be called to get the missing events + pdu: received pdu + prevs: List of event ids which we are missing + min_depth: Minimum depth of events to return. + """ + + room_id = pdu.room_id + event_id = pdu.event_id + + seen = await self.store.have_events_in_timeline(prevs) + + if not prevs - seen: + return + + latest_list = await self.store.get_latest_event_ids_in_room(room_id) + + # We add the prev events that we have seen to the latest + # list to ensure the remote server doesn't give them to us + latest = set(latest_list) + latest |= seen + + logger.info( + "Requesting missing events between %s and %s", + shortstr(latest), + event_id, + ) + + # XXX: we set timeout to 10s to help workaround + # https://github.com/matrix-org/synapse/issues/1733. + # The reason is to avoid holding the linearizer lock + # whilst processing inbound /send transactions, causing + # FDs to stack up and block other inbound transactions + # which empirically can currently take up to 30 minutes. + # + # N.B. this explicitly disables retry attempts. + # + # N.B. this also increases our chances of falling back to + # fetching fresh state for the room if the missing event + # can't be found, which slightly reduces our security. + # it may also increase our DAG extremity count for the room, + # causing additional state resolution? See #1760. + # However, fetching state doesn't hold the linearizer lock + # apparently. + # + # see https://github.com/matrix-org/synapse/pull/1744 + # + # ---- + # + # Update richvdh 2018/09/18: There are a number of problems with timing this + # request out aggressively on the client side: + # + # - it plays badly with the server-side rate-limiter, which starts tarpitting you + # if you send too many requests at once, so you end up with the server carefully + # working through the backlog of your requests, which you have already timed + # out. + # + # - for this request in particular, we now (as of + # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the + # server can't produce a plausible-looking set of prev_events - so we becone + # much more likely to reject the event. + # + # - contrary to what it says above, we do *not* fall back to fetching fresh state + # for the room if get_missing_events times out. Rather, we give up processing + # the PDU whose prevs we are missing, which then makes it much more likely that + # we'll end up back here for the *next* PDU in the list, which exacerbates the + # problem. + # + # - the aggressive 10s timeout was introduced to deal with incoming federation + # requests taking 8 hours to process. It's not entirely clear why that was going + # on; certainly there were other issues causing traffic storms which are now + # resolved, and I think in any case we may be more sensible about our locking + # now. We're *certainly* more sensible about our logging. + # + # All that said: Let's try increasing the timeout to 60s and see what happens. + + try: + missing_events = await self.federation_client.get_missing_events( + origin, + room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + timeout=60000, + ) + except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: + # We failed to get the missing events, but since we need to handle + # the case of `get_missing_events` not returning the necessary + # events anyway, it is safe to simply log the error and continue. + logger.warning("Failed to get prev_events: %s", e) + return + + logger.info("Got %d prev_events", len(missing_events)) + await self._process_pulled_events(origin, missing_events, backfilled=False) + + async def _process_pulled_events( + self, origin: str, events: Iterable[EventBase], backfilled: bool + ) -> None: + """Process a batch of events we have pulled from a remote server + + Pulls in any events required to auth the events, persists the received events, + and notifies clients, if appropriate. + + Assumes the events have already had their signatures and hashes checked. + + Params: + origin: The server we received these events from + events: The received events. + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) + """ + + # We want to sort these by depth so we process them and + # tell clients about them in order. + sorted_events = sorted(events, key=lambda x: x.depth) + + for ev in sorted_events: + with nested_logging_context(ev.event_id): + await self._process_pulled_event(origin, ev, backfilled=backfilled) + + async def _process_pulled_event( + self, origin: str, event: EventBase, backfilled: bool + ) -> None: + """Process a single event that we have pulled from a remote server + + Pulls in any events required to auth the event, persists the received event, + and notifies clients, if appropriate. + + Assumes the event has already had its signatures and hashes checked. + + This is somewhat equivalent to on_receive_pdu, but applies somewhat different + logic in the case that we are missing prev_events (in particular, it just + requests the state at that point, rather than triggering a get_missing_events) - + so is appropriate when we have pulled the event from a remote server, rather + than having it pushed to us. + + Params: + origin: The server we received this event from + events: The received event + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) + """ + logger.info("Processing pulled event %s", event) + + # these should not be outliers. + assert not event.internal_metadata.is_outlier() + + event_id = event.event_id + + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + try: + self._sanity_check_event(event) + except SynapseError as err: + logger.warning("Event %s failed sanity check: %s", event_id, err) + return + + try: + state = await self._resolve_state_at_missing_prevs(origin, event) + await self._process_received_pdu( + origin, event, state=state, backfilled=backfilled + ) + except FederationError as e: + if e.code == 403: + logger.warning("Pulled event %s failed history check.", event_id) + else: + raise + + async def _resolve_state_at_missing_prevs( + self, dest: str, event: EventBase + ) -> Optional[Iterable[EventBase]]: + """Calculate the state at an event with missing prev_events. + + This is used when we have pulled a batch of events from a remote server, and + still don't have all the prev_events. + + If we already have all the prev_events for `event`, this method does nothing. + + Otherwise, the missing prevs become new backwards extremities, and we fall back + to asking the remote server for the state after each missing `prev_event`, + and resolving across them. + + That's ok provided we then resolve the state against other bits of the DAG + before using it - in other words, that the received event `event` is not going + to become the only forwards_extremity in the room (which will ensure that you + can't just take over a room by sending an event, withholding its prev_events, + and declaring yourself to be an admin in the subsequent state request). + + In other words: we should only call this method if `event` has been *pulled* + as part of a batch of missing prev events, or similar. + + Params: + dest: the remote server to ask for state at the missing prevs. Typically, + this will be the server we got `event` from. + event: an event to check for missing prevs. + + Returns: + if we already had all the prev events, `None`. Otherwise, returns a list of + the events in the state at `event`. + """ + room_id = event.room_id + event_id = event.event_id + + prevs = set(event.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + return None + + logger.info( + "Event %s is missing prev_events %s: calculating state for a " + "backwards extremity", + event_id, + shortstr(missing_prevs), + ) + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. + event_map = {event_id: event} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps: List[StateMap[str]] = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in missing_prevs: + logger.info("Requesting state after missing prev_event %s", p) + + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state = await self._get_state_after_missing_prev_event( + dest, room_id, p + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = await self.store.get_room_version_id(room_id) + state_map = await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) + + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "Error attempting to resolve state at missing prev_events", + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) + return state + + async def _get_state_after_missing_prev_event( + self, + destination: str, + room_id: str, + event_id: str, + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + + Returns: + A list of events in the state, including the event itself + """ + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + logger.debug( + "state_ids returned %i state events, %i auth events", + len(state_event_ids), + len(auth_event_ids), + ) + + # start by just trying to fetch the events from the store + desired_events = set(state_event_ids) + desired_events.add(event_id) + logger.debug("Fetching %i events from cache/store", len(desired_events)) + fetched_events = await self.store.get_events( + desired_events, allow_rejected=True + ) + + missing_desired_events = desired_events - fetched_events.keys() + logger.debug( + "We are missing %i events (got %i)", + len(missing_desired_events), + len(fetched_events), + ) + + # We probably won't need most of the auth events, so let's just check which + # we have for now, rather than thrashing the event cache with them all + # unnecessarily. + + # TODO: we probably won't actually need all of the auth events, since we + # already have a bunch of the state events. It would be nice if the + # federation api gave us a way of finding out which we actually need. + + missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events.difference_update( + await self.store.have_seen_events(room_id, missing_auth_events) + ) + logger.debug("We are also missing %i auth events", len(missing_auth_events)) + + missing_events = missing_desired_events | missing_auth_events + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + await self.store.get_events(missing_desired_events, allow_rejected=True) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + # if we couldn't get the prev event in question, that's a problem. + remote_event = fetched_events.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. + failed_to_fetch = desired_events - fetched_events.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events + ] + + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + return remote_state + + async def _process_received_pdu( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + backfilled: bool = False, + ) -> None: + """Called when we have a new pdu. We need to do auth checks and put it + through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event + + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) + """ + logger.debug("Processing event: %s", event) + + try: + context = await self.state_handler.compute_event_context( + event, old_state=state + ) + await self._auth_and_persist_event( + origin, event, context, state=state, backfilled=backfilled + ) + except AuthError as e: + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + + if backfilled: + return + + # For encrypted messages we check that we know about the sending device, + # if we don't then we mark the device cache for that user as stale. + if event.type == EventTypes.Encrypted: + device_id = event.content.get("device_id") + sender_key = event.content.get("sender_key") + + cached_devices = await self.store.get_cached_devices_for_user(event.sender) + + resync = False # Whether we should resync device lists. + + device = None + if device_id is not None: + device = cached_devices.get(device_id) + if device is None: + logger.info( + "Received event from remote device not in our cache: %s %s", + event.sender, + device_id, + ) + resync = True + + # We also check if the `sender_key` matches what we expect. + if sender_key is not None: + # Figure out what sender key we're expecting. If we know the + # device and recognize the algorithm then we can work out the + # exact key to expect. Otherwise check it matches any key we + # have for that device. + + current_keys: Container[str] = [] + + if device: + keys = device.get("keys", {}).get("keys", {}) + + if ( + event.content.get("algorithm") + == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 + ): + # For this algorithm we expect a curve25519 key. + key_name = "curve25519:%s" % (device_id,) + current_keys = [keys.get(key_name)] + else: + # We don't know understand the algorithm, so we just + # check it matches a key for the device. + current_keys = keys.values() + elif device_id: + # We don't have any keys for the device ID. + pass + else: + # The event didn't include a device ID, so we just look for + # keys across all devices. + current_keys = [ + key + for device in cached_devices.values() + for key in device.get("keys", {}).get("keys", {}).values() + ] + + # We now check that the sender key matches (one of) the expected + # keys. + if sender_key not in current_keys: + logger.info( + "Received event from remote device with unexpected sender key: %s %s: %s", + event.sender, + device_id or "", + sender_key, + ) + resync = True + + if resync: + run_as_background_process( + "resync_device_due_to_pdu", + self._resync_device, + event.sender, + ) + + await self._handle_marker_event(origin, event) + + async def _resync_device(self, sender: str) -> None: + """We have detected that the device list for the given user may be out + of sync, so we try and resync them. + """ + + try: + await self.store.mark_remote_user_device_cache_as_stale(sender) + + # Immediately attempt a resync in the background + if self.config.worker_app: + await self._user_device_resync(user_id=sender) + else: + await self._device_list_updater.user_device_resync(sender) + except Exception: + logger.exception("Failed to resync device for %s", sender) + + async def _handle_marker_event(self, origin: str, marker_event: EventBase): + """Handles backfilling the insertion event when we receive a marker + event that points to one. + + Args: + origin: Origin of the event. Will be called to get the insertion event + marker_event: The event to process + """ + + if marker_event.type != EventTypes.MSC2716_MARKER: + # Not a marker event + return + + if marker_event.rejected_reason is not None: + # Rejected event + return + + # Skip processing a marker event if the room version doesn't + # support it. + room_version = await self.store.get_room_version(marker_event.room_id) + if not room_version.msc2716_historical: + return + + logger.debug("_handle_marker_event: received %s", marker_event) + + insertion_event_id = marker_event.content.get( + EventContentFields.MSC2716_MARKER_INSERTION + ) + + if insertion_event_id is None: + # Nothing to retrieve then (invalid marker) + return + + logger.debug( + "_handle_marker_event: backfilling insertion event %s", insertion_event_id + ) + + await self._get_events_and_persist( + origin, + marker_event.room_id, + [insertion_event_id], + ) + + insertion_event = await self.store.get_event( + insertion_event_id, allow_none=True + ) + if insertion_event is None: + logger.warning( + "_handle_marker_event: server %s didn't return insertion event %s for marker %s", + origin, + insertion_event_id, + marker_event.event_id, + ) + return + + logger.debug( + "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", + insertion_event, + marker_event, + ) + + await self.store.insert_insertion_extremity( + insertion_event_id, marker_event.room_id + ) + + logger.debug( + "_handle_marker_event: insertion extremity added for %s from marker event %s", + insertion_event, + marker_event, + ) + + async def _get_events_and_persist( + self, destination: str, room_id: str, events: Iterable[str] + ) -> None: + """Fetch the given events from a server, and persist them as outliers. + + This function *does not* recursively get missing auth events of the + newly fetched events. Callers must include in the `events` argument + any missing events from the auth chain. + + Logs a warning if we can't find the given event. + """ + + room_version = await self.store.get_room_version(room_id) + + event_map: Dict[str, EventBase] = {} + + async def get_event(event_id: str): + with nested_logging_context(event_id): + try: + event = await self.federation_client.get_pdu( + [destination], + event_id, + room_version, + outlier=True, + ) + if event is None: + logger.warning( + "Server %s didn't return event %s", + destination, + event_id, + ) + return + + event_map[event.event_id] = event + + except Exception as e: + logger.warning( + "Error fetching missing state/auth event %s: %s %s", + event_id, + type(e), + e, + ) + + await concurrently_execute(get_event, events, 5) + + # Make a map of auth events for each event. We do this after fetching + # all the events as some of the events' auth events will be in the list + # of requested events. + + auth_events = [ + aid + for event in event_map.values() + for aid in event.auth_event_ids() + if aid not in event_map + ] + persisted_events = await self.store.get_events( + auth_events, + allow_rejected=True, + ) + + event_infos = [] + for event in event_map.values(): + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + else: + logger.info("Missing auth event %s", auth_event_id) + + event_infos.append(_NewEventInfo(event, auth)) + + if event_infos: + await self._auth_and_persist_events( + destination, + room_id, + event_infos, + ) + + async def _auth_and_persist_events( + self, + origin: str, + room_id: str, + event_infos: Collection[_NewEventInfo], + ) -> None: + """Creates the appropriate contexts and persists events. The events + should not depend on one another, e.g. this should be used to persist + a bunch of outliers, but not a chunk of individual events that depend + on each other for state calculations. + + Notifies about the events where appropriate. + """ + + if not event_infos: + return + + async def prep(ev_info: _NewEventInfo): + event = ev_info.event + with nested_logging_context(suffix=event.event_id): + res = await self.state_handler.compute_event_context(event) + res = await self._check_event_auth( + origin, + event, + res, + claimed_auth_event_map=ev_info.claimed_auth_event_map, + ) + return res + + contexts = await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(prep, ev_info) for ev_info in event_infos], + consumeErrors=True, + ) + ) + + await self.persist_events_and_notify( + room_id, + [ + (ev_info.event, context) + for ev_info, context in zip(event_infos, contexts) + ], + ) + + async def _auth_and_persist_event( + self, + origin: str, + event: EventBase, + context: EventContext, + state: Optional[Iterable[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, + backfilled: bool = False, + ) -> None: + """ + Process an event by performing auth checks and then persisting to the database. + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. + + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers. + + backfilled: True if the event was backfilled. + """ + context = await self._check_event_auth( + origin, + event, + context, + state=state, + claimed_auth_event_map=claimed_auth_event_map, + backfilled=backfilled, + ) + + await self._run_push_actions_and_persist_event(event, context, backfilled) + + async def _check_event_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + state: Optional[Iterable[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, + backfilled: bool = False, + ) -> EventContext: + """ + Checks whether an event should be rejected (for failing auth checks). + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. + + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers, or the state for a backwards + extremity. + + backfilled: True if the event was backfilled. + + Returns: + The updated context object. + """ + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + if claimed_auth_event_map: + # if we have a copy of the auth events from the event, use that as the + # basis for auth. + auth_events = claimed_auth_event_map + else: + # otherwise, we calculate what the auth events *should* be, and use that + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + + try: + ( + context, + auth_events_for_auth, + ) = await self._update_auth_events_and_context_for_auth( + origin, event, context, auth_events + ) + except Exception: + # We don't really mind if the above fails, so lets not fail + # processing if it does. However, it really shouldn't fail so + # let's still log as an exception since we'll still want to fix + # any bugs. + logger.exception( + "Failed to double check auth events for %s with remote. " + "Ignoring failure and continuing processing of event.", + event.event_id, + ) + auth_events_for_auth = auth_events + + try: + event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) + except AuthError as e: + logger.warning("Failed auth resolution for %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + + if not context.rejected: + await self._check_for_soft_fail(event, state, backfilled, origin=origin) + + if event.type == EventTypes.GuestAccess and not context.rejected: + await self.maybe_kick_guest_users(event) + + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event( + event, context + ) + + return context + + async def _check_for_soft_fail( + self, + event: EventBase, + state: Optional[Iterable[EventBase]], + backfilled: bool, + origin: str, + ) -> None: + """Checks if we should soft fail the event; if so, marks the event as + such. + + Args: + event + state: The state at the event if we don't have all the event's prev events + backfilled: Whether the event is from backfill + origin: The host the event originates from. + """ + # For new (non-backfilled and non-outlier) events we check if the event + # passes auth based on the current state. If it doesn't then we + # "soft-fail" the event. + if backfilled or event.internal_metadata.is_outlier(): + return + + extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids_list) + prev_event_ids = set(event.prev_event_ids()) + + if extrem_ids == prev_event_ids: + # If they're the same then the current state is the same as the + # state at the event, so no point rechecking auth for soft fail. + return + + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + # Calculate the "current state". + if state is not None: + # If we're explicitly given the state then we won't have all the + # prev events, and so we have a gap in the graph. In this case + # we want to be a little careful as we might have been down for + # a while and have an incorrect view of the current state, + # however we still want to do checks as gaps are easy to + # maliciously manufacture. + # + # So we use a "current state" that is actually a state + # resolution across the current forward extremities and the + # given state at the event. This should correctly handle cases + # like bans, especially with state res v2. + + state_sets_d = await self.state_store.get_state_groups( + event.room_id, extrem_ids + ) + state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) + state_sets.append(state) + current_states = await self.state_handler.resolve_events( + room_version, state_sets, event + ) + current_state_ids: StateMap[str] = { + k: e.event_id for k, e in current_states.items() + } + else: + current_state_ids = await self.state_handler.get_current_state_ids( + event.room_id, latest_event_ids=extrem_ids + ) + + logger.debug( + "Doing soft-fail check for %s: state %s", + event.event_id, + current_state_ids, + ) + + # Now check if event pass auth against said current state + auth_types = auth_types_for_event(room_version_obj, event) + current_state_ids_list = [ + e for k, e in current_state_ids.items() if k in auth_types + ] + + auth_events_map = await self.store.get_events(current_state_ids_list) + current_auth_events = { + (e.type, e.state_key): e for e in auth_events_map.values() + } + + try: + event_auth.check(room_version_obj, event, auth_events=current_auth_events) + except AuthError as e: + logger.warning( + "Soft-failing %r (from %s) because %s", + event, + e, + origin, + extra={ + "room_id": event.room_id, + "mxid": event.sender, + "hs": origin, + }, + ) + soft_failed_event_counter.inc() + event.internal_metadata.soft_failed = True + + async def _update_auth_events_and_context_for_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + input_auth_events: StateMap[EventBase], + ) -> Tuple[EventContext, StateMap[EventBase]]: + """Helper for _check_event_auth. See there for docs. + + Checks whether a given event has the expected auth events. If it + doesn't then we talk to the remote server to compare state to see if + we can come to a consensus (e.g. if one server missed some valid + state). + + This attempts to resolve any potential divergence of state between + servers, but is not essential and so failures should not block further + processing of the event. + + Args: + origin: + event: + context: + + input_auth_events: + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + + Returns: + updated context, updated auth event map + """ + # take a copy of input_auth_events before we modify it. + auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + + event_auth_events = set(event.auth_event_ids()) + + # missing_auth is the set of the event's auth_events which we don't yet have + # in auth_events. + missing_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + # if we have missing events, we need to fetch those events from somewhere. + # + # we start by checking if they are in the store, and then try calling /event_auth/. + if missing_auth: + have_events = await self.store.have_seen_events(event.room_id, missing_auth) + logger.debug("Events %s are in the store", have_events) + missing_auth.difference_update(have_events) + + if missing_auth: + # If we don't have all the auth events, we need to get them. + logger.info("auth_events contains unknown events: %s", missing_auth) + try: + try: + remote_auth_chain = await self.federation_client.get_event_auth( + origin, event.room_id, event.event_id + ) + except RequestSendFailed as e1: + # The other side isn't around or doesn't implement the + # endpoint, so lets just bail out. + logger.info("Failed to get event auth from remote: %s", e1) + return context, auth_events + + seen_remotes = await self.store.have_seen_events( + event.room_id, [e.event_id for e in remote_auth_chain] + ) + + for e in remote_auth_chain: + if e.event_id in seen_remotes: + continue + + if e.event_id == event.event_id: + continue + + try: + auth_ids = e.auth_event_ids() + auth = { + (e.type, e.state_key): e + for e in remote_auth_chain + if e.event_id in auth_ids or e.type == EventTypes.Create + } + e.internal_metadata.outlier = True + + logger.debug( + "_check_event_auth %s missing_auth: %s", + event.event_id, + e.event_id, + ) + missing_auth_event_context = ( + await self.state_handler.compute_event_context(e) + ) + await self._auth_and_persist_event( + origin, + e, + missing_auth_event_context, + claimed_auth_event_map=auth, + ) + + if e.event_id in event_auth_events: + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + except Exception: + logger.exception("Failed to get auth chain") + + if event.internal_metadata.is_outlier(): + # XXX: given that, for an outlier, we'll be working with the + # event's *claimed* auth events rather than those we calculated: + # (a) is there any point in this test, since different_auth below will + # obviously be empty + # (b) alternatively, why don't we do it earlier? + logger.info("Skipping auth_event fetch for outlier") + return context, auth_events + + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + if not different_auth: + return context, auth_events + + logger.info( + "auth_events refers to events which are not in our calculated auth " + "chain: %s", + different_auth, + ) + + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? + different_events = await self.store.get_events_as_list(different_auth) + + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) + + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context, auth_events + + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. + + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() + + room_version = await self.store.get_room_version_id(event.room_id) + new_state = await self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) + + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) + + auth_events.update(new_state) + + context = await self._update_context_for_auth_events( + event, context, auth_events + ) + + return context, auth_events + + async def _update_context_for_auth_events( + self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] + ) -> EventContext: + """Update the state_ids in an event context after auth event resolution, + storing the changes as a new state group. + + Args: + event: The event we're handling the context for + + context: initial event context + + auth_events: Events to update in the event context. + + Returns: + new event context + """ + # exclude the state key of the new event from the current_state in the context. + if event.is_state(): + event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) + else: + event_key = None + state_updates = { + k: a.event_id for k, a in auth_events.items() if k != event_key + } + + current_state_ids = await context.get_current_state_ids() + current_state_ids = dict(current_state_ids) # type: ignore + + current_state_ids.update(state_updates) + + prev_state_ids = await context.get_prev_state_ids() + prev_state_ids = dict(prev_state_ids) + + prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) + + # create a new state group as a delta from the existing one. + prev_group = context.state_group + state_group = await self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=prev_group, + delta_ids=state_updates, + current_state_ids=current_state_ids, + ) + + return EventContext.with_state( + state_group=state_group, + state_group_before_event=context.state_group_before_event, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=prev_group, + delta_ids=state_updates, + ) + + async def _run_push_actions_and_persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ): + """Run the push actions for a received event, and persist it. + + Args: + event: The event itself. + context: The event context. + backfilled: True if the event was backfilled. + """ + try: + if ( + not event.internal_metadata.is_outlier() + and not backfilled + and not context.rejected + and (await self.store.get_min_depth(event.room_id)) <= event.depth + ): + await self.action_generator.handle_push_actions_for_event( + event, context + ) + + await self.persist_events_and_notify( + event.room_id, [(event, context)], backfilled=backfilled + ) + except Exception: + run_in_background( + self.store.remove_push_actions_from_staging, event.event_id + ) + raise + + async def persist_events_and_notify( + self, + room_id: str, + event_and_contexts: Sequence[Tuple[EventBase, EventContext]], + backfilled: bool = False, + ) -> int: + """Persists events and tells the notifier/pushers about them, if + necessary. + + Args: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. + backfilled: Whether these events are a result of + backfilling or not + + Returns: + The stream ID after which all events have been persisted. + """ + if not event_and_contexts: + return self.store.get_current_events_token() + + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: + # Limit the number of events sent over replication. We choose 200 + # here as that is what we default to in `max_request_body_size(..)` + for batch in batch_iter(event_and_contexts, 200): + result = await self._send_events( + instance_name=instance, + store=self.store, + room_id=room_id, + event_and_contexts=batch, + backfilled=backfilled, + ) + return result["max_stream_id"] + else: + assert self.storage.persistence + + # Note that this returns the events that were persisted, which may not be + # the same as were passed in if some were deduplicated due to transaction IDs. + events, max_stream_token = await self.storage.persistence.persist_events( + event_and_contexts, backfilled=backfilled + ) + + if self._ephemeral_messages_enabled: + for event in events: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + + if not backfilled: # Never notify for backfilled events + for event in events: + await self._notify_persisted_event(event, max_stream_token) + + return max_stream_token.stream + + async def _notify_persisted_event( + self, event: EventBase, max_stream_token: RoomStreamToken + ) -> None: + """Checks to see if notifier/pushers should be notified about the + event or not. + + Args: + event: + max_stream_token: The max_stream_id returned by persist_events + """ + + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + + # We notify for memberships if its an invite for one of our + # users + if event.internal_metadata.is_outlier(): + if event.membership != Membership.INVITE: + if not self.is_mine_id(target_user_id): + return + + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) + elif event.internal_metadata.is_outlier(): + return + + # the event has been persisted so it should have a stream ordering. + assert event.internal_metadata.stream_ordering + + event_pos = PersistedEventPosition( + self._instance_name, event.internal_metadata.stream_ordering + ) + self.notifier.on_new_room_event( + event, event_pos, max_stream_token, extra_users=extra_users + ) + + def _sanity_check_event(self, ev: EventBase) -> None: + """ + Do some early sanity checks of a received event + + In particular, checks it doesn't have an excessive number of + prev_events or auth_events, which could cause a huge state resolution + or cascade of event fetches. + + Args: + ev: event to be checked + + Raises: + SynapseError if the event does not pass muster + """ + if len(ev.prev_event_ids()) > 20: + logger.warning( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") + + if len(ev.auth_event_ids()) > 10: + logger.warning( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + + async def get_min_depth_for_context(self, context: str) -> int: + return await self.store.get_min_depth(context) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 79cadb7b574c..a0b3145f4e32 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -62,7 +62,7 @@ def __init__(self, hs): self.store = hs.get_datastore() self.storage = hs.get_storage() self.clock = hs.get_clock() - self.federation_handler = hs.get_federation_handler() + self.federation_event_handler = hs.get_federation_event_handler() @staticmethod async def _serialize_payload(store, room_id, event_and_contexts, backfilled): @@ -127,7 +127,7 @@ async def _handle_request(self, request): logger.info("Got %d events from federation", len(event_and_contexts)) - max_stream_id = await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_event_handler.persist_events_and_notify( room_id, event_and_contexts, backfilled ) diff --git a/synapse/server.py b/synapse/server.py index de6517663e6b..5adeeff61a5f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -76,6 +76,7 @@ from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler +from synapse.handlers.federation_event import FederationEventHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler @@ -546,6 +547,10 @@ def get_event_stream_handler(self) -> EventStreamHandler: def get_federation_handler(self) -> FederationHandler: return FederationHandler(self) + @cache_in_self + def get_federation_event_handler(self) -> FederationEventHandler: + return FederationEventHandler(self) + @cache_in_self def get_identity_handler(self) -> IdentityHandler: return IdentityHandler(self) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 383214ab5046..663960ff534a 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -208,7 +208,7 @@ async def approve_all_signature_checking(_, pdu): async def _check_event_auth(origin, event, context, *args, **kwargs): return context - homeserver.get_federation_handler()._check_event_auth = _check_event_auth + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth return super().prepare(reactor, clock, homeserver) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c72a8972a3f1..6c67a16de923 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -130,7 +130,9 @@ def test_rejected_message_event_state(self): ) with LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -182,7 +184,9 @@ def test_rejected_state_event_state(self): ) with LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -311,7 +315,9 @@ async def get_event_auth( with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver d = run_in_background( - self.handler.on_receive_pdu, OTHER_SERVER, message_event + self.hs.get_federation_event_handler().on_receive_pdu, + OTHER_SERVER, + message_event, ) self.get_success(d) @@ -382,7 +388,9 @@ def _build_and_send_join_event(self, other_server, other_user, room_id): join_event.signatures[other_server] = {"x": "y"} with LoggingContext("send_join"): d = run_in_background( - self.handler.on_send_membership_event, other_server, join_event + self.hs.get_federation_event_handler().on_send_membership_event, + other_server, + join_event, ) self.get_success(d) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 0a52bc8b721f..671dc7d083c8 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -885,7 +885,7 @@ def default_config(self): def prepare(self, reactor, clock, hs): self.federation_sender = hs.get_federation_sender() self.event_builder_factory = hs.get_event_builder_factory() - self.federation_handler = hs.get_federation_handler() + self.federation_event_handler = hs.get_federation_event_handler() self.presence_handler = hs.get_presence_handler() # self.event_builder_for_2 = EventBuilderFactory(hs) @@ -1026,7 +1026,7 @@ def _add_new_user(self, room_id, user_id): builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) ) - self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) + self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) # Check that it was successfully persisted. self.get_success(self.store.get_event(event.event_id)) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index af5dfca752b9..92a5b53e11e7 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -205,7 +205,7 @@ def test_send_typing_sharded(self): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) store = self.hs.get_datastore() - federation = self.hs.get_federation_handler() + federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) room_version = self.get_success(store.get_room_version(room)) diff --git a/tests/test_federation.py b/tests/test_federation.py index 348fcb72a7e0..61c9d7c2ef96 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,7 +75,8 @@ def setUp(self): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + federation_event_handler = self.homeserver.get_federation_event_handler() + federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( context ) self.client = self.homeserver.get_federation_client() @@ -85,7 +86,9 @@ def setUp(self): # Send the join, it should return None (which is not an error) self.assertEqual( - self.get_success(self.handler.on_receive_pdu("test.serv", join_event)), + self.get_success( + federation_event_handler.on_receive_pdu("test.serv", join_event) + ), None, ) @@ -129,9 +132,10 @@ async def post_json(destination, path, data, headers=None, timeout=0): } ) + federation_event_handler = self.homeserver.get_federation_event_handler() with LoggingContext("test-context"): failure = self.get_failure( - self.handler.on_receive_pdu("test.serv", lying_event), + federation_event_handler.on_receive_pdu("test.serv", lying_event), FederationError, ) From c4fa4f37cbc734f9cd6354a5f2661efc30d73cac Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Aug 2021 10:15:50 +0100 Subject: [PATCH 23/28] Fix perf of fetching the same events many times. (#10703) The code to deduplicate repeated fetches of the same set of events was N^2 (over the number of events requested), which could lead to a process being completely wedged. The main fix is to deduplicate the returned deferreds so we only await on a deferred once rather than many times. Seperately, when handling the returned events from the defrered we only add the events we care about to the event map to be returned (so that we don't pay the price of inserting extraneous events into the dict). --- changelog.d/10703.bugfix | 1 + .../storage/databases/main/events_worker.py | 29 +++++++++++++++---- 2 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10703.bugfix diff --git a/changelog.d/10703.bugfix b/changelog.d/10703.bugfix new file mode 100644 index 000000000000..a5a4ecf8eedf --- /dev/null +++ b/changelog.d/10703.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in v1.41.0 which affected the performance of concurrent fetches of large sets of events, in extreme cases causing the process to hang. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 375463e4e979..9501f00f3bb3 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -520,16 +520,26 @@ async def _get_events_from_cache_or_db( # We now look up if we're already fetching some of the events in the DB, # if so we wait for those lookups to finish instead of pulling the same # events out of the DB multiple times. - already_fetching: Dict[str, defer.Deferred] = {} + # + # Note: we might get the same `ObservableDeferred` back for multiple + # events we're already fetching, so we deduplicate the deferreds to + # avoid extraneous work (if we don't do this we can end up in a n^2 mode + # when we wait on the same Deferred N times, then try and merge the + # same dict into itself N times). + already_fetching_ids: Set[str] = set() + already_fetching_deferreds: Set[ + ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = set() for event_id in missing_events_ids: deferred = self._current_event_fetches.get(event_id) if deferred is not None: # We're already pulling the event out of the DB. Add the deferred # to the collection of deferreds to wait on. - already_fetching[event_id] = deferred.observe() + already_fetching_ids.add(event_id) + already_fetching_deferreds.add(deferred) - missing_events_ids.difference_update(already_fetching) + missing_events_ids.difference_update(already_fetching_ids) if missing_events_ids: log_ctx = current_context() @@ -569,18 +579,25 @@ async def _get_events_from_cache_or_db( with PreserveLoggingContext(): fetching_deferred.callback(missing_events) - if already_fetching: + if already_fetching_deferreds: # Wait for the other event requests to finish and add their results # to ours. results = await make_deferred_yieldable( defer.gatherResults( - already_fetching.values(), + (d.observe() for d in already_fetching_deferreds), consumeErrors=True, ) ).addErrback(unwrapFirstError) for result in results: - event_entry_map.update(result) + # We filter out events that we haven't asked for as we might get + # a *lot* of superfluous events back, and there is no point + # going through and inserting them all (which can take time). + event_entry_map.update( + (event_id, entry) + for event_id, entry in result.items() + if event_id in already_fetching_ids + ) if not allow_rejected: event_entry_map = { From e62cdbef1a499f428e48f98167b2b709d16c671d Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 27 Aug 2021 11:16:40 +0200 Subject: [PATCH 24/28] Improve ServerNoticeServlet to avoid duplicate requests (#10679) Fixes: #9544 --- changelog.d/10679.bugfix | 1 + synapse/rest/admin/__init__.py | 5 +- synapse/rest/admin/server_notice_servlet.py | 19 +- .../server_notices/server_notices_manager.py | 17 +- tests/rest/admin/test_server_notice.py | 450 ++++++++++++++++++ 5 files changed, 475 insertions(+), 17 deletions(-) create mode 100644 changelog.d/10679.bugfix create mode 100644 tests/rest/admin/test_server_notice.py diff --git a/changelog.d/10679.bugfix b/changelog.d/10679.bugfix new file mode 100644 index 000000000000..5c4061f6d552 --- /dev/null +++ b/changelog.d/10679.bugfix @@ -0,0 +1 @@ +Improve ServerNoticeServlet to avoid duplicate requests and add unit tests. \ No newline at end of file diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6e1c8736e1dd..b2514d9d0df4 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -223,7 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) @@ -247,6 +246,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + # Some servlets only get registered for the main process. + if hs.config.worker_app is None: + SendServerNoticeServlet(hs).register(http_server) + def register_servlets_for_client_rest_resource( hs: "HomeServer", http_server: HttpServer diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b5e4c474efc8..42201afc86d3 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() + self.server_notices_manager = hs.get_server_notices_manager() + self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) def register(self, json_resource: HttpServer): @@ -79,19 +81,22 @@ async def on_POST( # We grab the server notices manager here as its initialisation has a check for worker processes, # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). - if not self.hs.get_server_notices_manager().is_enabled(): + if not self.server_notices_manager.is_enabled(): raise SynapseError(400, "Server notices are not enabled on this server") - user_id = body["user_id"] - UserID.from_string(user_id) - if not self.hs.is_mine_id(user_id): + target_user = UserID.from_string(body["user_id"]) + if not self.hs.is_mine(target_user): raise SynapseError(400, "Server notices can only be sent to local users") - event = await self.hs.get_server_notices_manager().send_notice( - user_id=body["user_id"], + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + event = await self.server_notices_manager.send_notice( + user_id=target_user.to_string(), type=event_type, state_key=state_key, event_content=body["content"], + txn_id=txn_id, ) return 200, {"event_id": event.event_id} diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index f19075b76050..d87a53891740 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -12,26 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.events import EventBase from synapse.types import UserID, create_requester from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ - + def __init__(self, hs: "HomeServer"): self._store = hs.get_datastore() self._config = hs.config self._account_data_handler = hs.get_account_data_handler() @@ -58,6 +55,7 @@ async def send_notice( event_content: dict, type: str = EventTypes.Message, state_key: Optional[str] = None, + txn_id: Optional[str] = None, ) -> EventBase: """Send a notice to the given user @@ -68,6 +66,7 @@ async def send_notice( event_content: content of event to send type: type of event is_state_event: Is the event a state event + txn_id: The transaction ID. """ room_id = await self.get_or_create_notice_room_for_user(user_id) await self.maybe_invite_user_to_room(user_id, room_id) @@ -90,7 +89,7 @@ async def send_notice( event_dict["state_key"] = state_key event, _ = await self._event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, ratelimit=False + requester, event_dict, ratelimit=False, txn_id=txn_id ) return event diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py new file mode 100644 index 000000000000..fbceba325494 --- /dev/null +++ b/tests/rest/admin/test_server_notice.py @@ -0,0 +1,450 @@ +# Copyright 2021 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login, room, sync +from synapse.storage.roommember import RoomsForUser +from synapse.types import JsonDict + +from tests import unittest +from tests.unittest import override_config + + +class ServerNoticeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() + self.server_notices_manager = self.hs.get_server_notices_manager() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/send_server_notice" + + def test_no_auth(self): + """Try to send a server notice without authentication.""" + channel = self.make_request("POST", self.url) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """If the user is not a server admin, an error is returned.""" + channel = self.make_request( + "POST", + self.url, + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_does_not_exist(self): + """Tests that a lookup for a user that does not exist returns a 404""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": "@unknown_person:test", "content": ""}, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": "@unknown_person:unknown_domain", + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Server notices can only be sent to local users", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_invalid_parameter(self): + """If parameters are invalid, an error is returned.""" + + # no content, no user + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) + + # no content + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # no body + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user, "content": ""}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'body' not in content", channel.json_body["error"]) + + # no msgtype + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user, "content": {"body": ""}}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'msgtype' not in content", channel.json_body["error"]) + + def test_server_notice_disabled(self): + """Tests that server returns error if server notice is disabled""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Server notices are not enabled on this server", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice(self): + """ + Tests that sending two server notices is successfully, + the server uses the same room and do not send messages twice. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # invalidate cache of server notices room_ids + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has no new invites or memberships + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + + self.assertEqual(len(messages), 2) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + self.assertEqual(messages[1]["content"]["body"], "test msg two") + self.assertEqual(messages[1]["sender"], "@notices:test") + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_leave_room(self): + """ + Tests that sending a server notices is successfully. + The user leaves the room and the second message appears + in a new room. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # user leaves the romm + self.helper.leave( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id the user gets the message + # in old room + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # room has the same id + self.assertNotEqual(first_room_id, second_room_id) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_delete_room(self): + """ + Tests that the user get server notice in a new room + after the first server notice room was deleted. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # shut down and purge room + self.get_success( + self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user) + ) + self.get_success(self.pagination_handler.purge_room(first_room_id)) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(first_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id it gives an error + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get message + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # second room has new ID + self.assertNotEqual(first_room_id, second_room_id) + + def _check_invite_and_join_status( + self, user_id: str, expected_invites: int, expected_memberships: int + ) -> RoomsForUser: + """Check invite and room membership status of a user. + + Args + user_id: user to check + expected_invites: number of expected invites of this user + expected_memberships: number of expected room memberships of this user + Returns + room_ids from the rooms that the user is invited + """ + + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(user_id) + ) + self.assertEqual(expected_invites, len(invited_rooms)) + + room_ids = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(expected_memberships, len(room_ids)) + + return invited_rooms + + def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]: + """ + Do a sync and get messages of a room. + + Args + room_id: room that contains the messages + token: access token of user + + Returns + list of messages contained in the room + """ + channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=token + ) + self.assertEqual(channel.code, 200) + + # Get the messages + room = channel.json_body["rooms"]["join"][room_id] + messages = [ + x for x in room["timeline"]["events"] if x["type"] == "m.room.message" + ] + return messages From 029b7ad7b94d167b19d63a5dc777a806b0e073f3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 Aug 2021 07:08:02 -0400 Subject: [PATCH 25/28] Remove unused `compare_digest` function. (#10706) --- changelog.d/10706.misc | 1 + synapse/rest/client/register.py | 13 ------------- 2 files changed, 1 insertion(+), 13 deletions(-) create mode 100644 changelog.d/10706.misc diff --git a/changelog.d/10706.misc b/changelog.d/10706.misc new file mode 100644 index 000000000000..eed4aa58d621 --- /dev/null +++ b/changelog.d/10706.misc @@ -0,0 +1 @@ +Remove unused `compare_digest` function. diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 2781a0ea96df..7b5f49d635cb 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import hmac import logging import random from typing import List, Union @@ -60,18 +59,6 @@ from ._base import client_patterns, interactive_auth_handler -# We ought to be using hmac.compare_digest() but on older pythons it doesn't -# exist. It's a _really minor_ security flaw to use plain string comparison -# because the timing attack is so obscured by all the other code here it's -# unlikely to make much difference -if hasattr(hmac, "compare_digest"): - compare_digest = hmac.compare_digest -else: - - def compare_digest(a, b): - return a == b - - logger = logging.getLogger(__name__) From 051ddac53b733e5768488bac7548a0c31bf68982 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 27 Aug 2021 12:54:21 +0100 Subject: [PATCH 26/28] Clarifications to reverse_proxy.md (#10708) * Update reverse_proxy.md * Create 10708.doc --- changelog.d/10708.doc | 1 + docs/reverse_proxy.md | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10708.doc diff --git a/changelog.d/10708.doc b/changelog.d/10708.doc new file mode 100644 index 000000000000..99f9d69288c8 --- /dev/null +++ b/changelog.d/10708.doc @@ -0,0 +1 @@ +Minor clarifications to the documentation for reverse proxies. diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 5f8d20129e1a..bc351d604e40 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -64,6 +64,9 @@ server { server_name matrix.example.com; location ~* ^(\/_matrix|\/_synapse\/client) { + # note: do not add a path (even a single /) after the port in `proxy_pass`, + # otherwise nginx will canonicalise the URI and cause signature verification + # errors. proxy_pass http://localhost:8008; proxy_set_header X-Forwarded-For $remote_addr; proxy_set_header X-Forwarded-Proto $scheme; @@ -76,10 +79,7 @@ server { } ``` -**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will -canonicalise/normalise the URI. - -### Caddy 1 +### Caddy v1 ``` matrix.example.com { @@ -99,7 +99,7 @@ example.com:8448 { } ``` -### Caddy 2 +### Caddy v2 ``` matrix.example.com { From 54aa7047ebf0d2605e31bdd4933effc4eb63813b Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Fri, 27 Aug 2021 15:19:17 +0100 Subject: [PATCH 27/28] Removed page summaries from the top of installation and contributing doc pages (#10711) - Removed page summaries from CONTRIBUTING and installation pages as this information was already in the table of contents on the right hand side - Fixed some broken links in CONTRIBUTING - Added margin-right tag for when table of contents is being shown (otherwise the text in the page sometimes overlaps with it) --- CONTRIBUTING.md | 49 +++++++----------------- changelog.d/10711.doc | 1 + docs/setup/installation.md | 39 ------------------- docs/website_files/table-of-contents.css | 7 +++- 4 files changed, 21 insertions(+), 75 deletions(-) create mode 100644 changelog.d/10711.doc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cd6c34df85b1..31d0a47fdf52 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,30 +2,6 @@ Welcome to Synapse This document aims to get you started with contributing to this repo! -- [1. Who can contribute to Synapse?](#1-who-can-contribute-to-synapse) -- [2. What do I need?](#2-what-do-i-need) -- [3. Get the source.](#3-get-the-source) -- [4. Install the dependencies](#4-install-the-dependencies) - * [Under Unix (macOS, Linux, BSD, ...)](#under-unix-macos-linux-bsd-) - * [Under Windows](#under-windows) -- [5. Get in touch.](#5-get-in-touch) -- [6. Pick an issue.](#6-pick-an-issue) -- [7. Turn coffee and documentation into code and documentation!](#7-turn-coffee-and-documentation-into-code-and-documentation) -- [8. Test, test, test!](#8-test-test-test) - * [Run the linters.](#run-the-linters) - * [Run the unit tests.](#run-the-unit-tests-twisted-trial) - * [Run the integration tests (SyTest).](#run-the-integration-tests-sytest) - * [Run the integration tests (Complement).](#run-the-integration-tests-complement) -- [9. Submit your patch.](#9-submit-your-patch) - * [Changelog](#changelog) - + [How do I know what to call the changelog file before I create the PR?](#how-do-i-know-what-to-call-the-changelog-file-before-i-create-the-pr) - + [Debian changelog](#debian-changelog) - * [Sign off](#sign-off) -- [10. Turn feedback into better code.](#10-turn-feedback-into-better-code) -- [11. Find a new issue.](#11-find-a-new-issue) -- [Notes for maintainers on merging PRs etc](#notes-for-maintainers-on-merging-prs-etc) -- [Conclusion](#conclusion) - # 1. Who can contribute to Synapse? Everyone is welcome to contribute code to [matrix.org @@ -35,7 +11,7 @@ follow a simple 'inbound=outbound' model for contributions: the act of submitting an 'inbound' contribution means that the contributor agrees to license the code under the same terms as the project's overall 'outbound' license - in our case, this is almost always Apache Software License v2 (see -[LICENSE](LICENSE)). +[LICENSE](https://github.com/matrix-org/synapse/blob/develop/LICENSE)). # 2. What do I need? @@ -98,17 +74,20 @@ to work on. # 7. Turn coffee and documentation into code and documentation! -Synapse's code style is documented [here](docs/code_style.md). Please follow -it, including the conventions for the [sample configuration -file](docs/code_style.md#configuration-file-format). +Synapse's code style is documented +[here](https://matrix-org.github.io/synapse/develop/code_style.html). +Please follow it, including the conventions for the +[sample configuration file](https://matrix-org.github.io/synapse/develop/code_style.html#configuration-file-format). -There is a growing amount of documentation located in the [docs](docs) +There is a growing amount of documentation located in the +[docs](https://github.com/matrix-org/synapse/tree/develop/docs) directory. This documentation is intended primarily for sysadmins running their -own Synapse instance, as well as developers interacting externally with -Synapse. [docs/dev](docs/dev) exists primarily to house documentation for -Synapse developers. [docs/admin_api](docs/admin_api) houses documentation -regarding Synapse's Admin API, which is used mostly by sysadmins and external -service developers. +own Synapse instance, as well as developers interacting externally with Synapse. +[docs/development](https://github.com/matrix-org/synapse/tree/develop/docs/development) +exists primarily to house documentation for Synapse developers. +[docs/admin_api](https://github.com/matrix-org/synapse/tree/develop/docs/admin_api) +houses documentation regarding Synapse's Admin API, which is used mostly by sysadmins +and external service developers. If you add new files added to either of these folders, please use [GitHub-Flavoured Markdown](https://guides.github.com/features/mastering-markdown/). @@ -431,7 +410,7 @@ By now, you know the drill! # Notes for maintainers on merging PRs etc There are some notes for those with commit access to the project on how we -manage git [here](docs/development/git.md). +manage git [here](https://matrix-org.github.io/synapse/develop/development/git.html). # Conclusion diff --git a/changelog.d/10711.doc b/changelog.d/10711.doc new file mode 100644 index 000000000000..c495f98be859 --- /dev/null +++ b/changelog.d/10711.doc @@ -0,0 +1 @@ +Removed table of contents from the top of installation and contributing documentation pages. \ No newline at end of file diff --git a/docs/setup/installation.md b/docs/setup/installation.md index 8540a7b0c10d..06f869cd75fa 100644 --- a/docs/setup/installation.md +++ b/docs/setup/installation.md @@ -1,44 +1,5 @@ # Installation Instructions -There are 3 steps to follow under **Installation Instructions**. - -- [Installation Instructions](#installation-instructions) - - [Choosing your server name](#choosing-your-server-name) - - [Installing Synapse](#installing-synapse) - - [Installing from source](#installing-from-source) - - [Platform-specific prerequisites](#platform-specific-prerequisites) - - [Debian/Ubuntu/Raspbian](#debianubunturaspbian) - - [ArchLinux](#archlinux) - - [CentOS/Fedora](#centosfedora) - - [macOS](#macos) - - [OpenSUSE](#opensuse) - - [OpenBSD](#openbsd) - - [Windows](#windows) - - [Prebuilt packages](#prebuilt-packages) - - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks) - - [Debian/Ubuntu](#debianubuntu) - - [Matrix.org packages](#matrixorg-packages) - - [Downstream Debian packages](#downstream-debian-packages) - - [Downstream Ubuntu packages](#downstream-ubuntu-packages) - - [Fedora](#fedora) - - [OpenSUSE](#opensuse-1) - - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server) - - [ArchLinux](#archlinux-1) - - [Void Linux](#void-linux) - - [FreeBSD](#freebsd) - - [OpenBSD](#openbsd-1) - - [NixOS](#nixos) - - [Setting up Synapse](#setting-up-synapse) - - [Using PostgreSQL](#using-postgresql) - - [TLS certificates](#tls-certificates) - - [Client Well-Known URI](#client-well-known-uri) - - [Email](#email) - - [Registering a user](#registering-a-user) - - [Setting up a TURN server](#setting-up-a-turn-server) - - [URL previews](#url-previews) - - [Troubleshooting Installation](#troubleshooting-installation) - - ## Choosing your server name It is important to choose the name for your server before you install Synapse, diff --git a/docs/website_files/table-of-contents.css b/docs/website_files/table-of-contents.css index d16bb3b9886b..1b6f44b66a2e 100644 --- a/docs/website_files/table-of-contents.css +++ b/docs/website_files/table-of-contents.css @@ -1,3 +1,7 @@ +:root { + --pagetoc-width: 250px; +} + @media only screen and (max-width:1439px) { .sidetoc { display: none; @@ -8,6 +12,7 @@ main { position: relative; margin-left: 100px !important; + margin-right: var(--pagetoc-width) !important; } .sidetoc { margin-left: auto; @@ -18,7 +23,7 @@ } .pagetoc { position: fixed; - width: 250px; + width: var(--pagetoc-width); overflow: auto; right: 20px; height: calc(100% - var(--menu-bar-height)); From 8f98260552f4f39f003bc1fbf6da159d9138081d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 27 Aug 2021 16:33:41 +0100 Subject: [PATCH 28/28] Fix incompatibility with Twisted < 21. (#10713) Turns out that the functionality added in #10546 to skip TLS was incompatible with older Twisted versions, so we need to be a bit more inventive. Also, add a test to (hopefully) not break this in future. Sadly, testing TLS is really hard. --- changelog.d/10713.bugfix | 1 + mypy.ini | 1 + synapse/handlers/send_email.py | 65 ++++++++++++----- tests/handlers/test_send_email.py | 112 ++++++++++++++++++++++++++++++ tests/server.py | 15 +++- 5 files changed, 173 insertions(+), 21 deletions(-) create mode 100644 changelog.d/10713.bugfix create mode 100644 tests/handlers/test_send_email.py diff --git a/changelog.d/10713.bugfix b/changelog.d/10713.bugfix new file mode 100644 index 000000000000..e8caf3d23aaa --- /dev/null +++ b/changelog.d/10713.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library. diff --git a/mypy.ini b/mypy.ini index e1b9405daa85..349efe37bbc4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -87,6 +87,7 @@ files = tests/test_utils, tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, + tests/handlers/test_send_email.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index dda9659c11c2..a31fe3e3c7ef 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -19,9 +19,12 @@ from io import BytesIO from typing import TYPE_CHECKING, Optional +from pkg_resources import parse_version + +import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IReactorTCP -from twisted.mail.smtp import ESMTPSenderFactory +from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP +from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable @@ -30,6 +33,19 @@ logger = logging.getLogger(__name__) +_is_old_twisted = parse_version(twisted.__version__) < parse_version("21") + + +class _NoTLSESMTPSender(ESMTPSender): + """Extend ESMTPSender to disable TLS + + Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to disable + TLS, so we override its internal method which it uses to generate a context factory. + """ + + def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]: + return None + async def _sendmail( reactor: IReactorTCP, @@ -42,7 +58,7 @@ async def _sendmail( password: Optional[bytes] = None, require_auth: bool = False, require_tls: bool = False, - tls_hostname: Optional[str] = None, + enable_tls: bool = True, ) -> None: """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests @@ -57,24 +73,37 @@ async def _sendmail( password: password to give when authenticating require_auth: if auth is not offered, fail the request require_tls: if TLS is not offered, fail the reqest - tls_hostname: TLS hostname to check for. None to disable TLS. + enable_tls: True to enable TLS. If this is False and require_tls is True, + the request will fail. """ msg = BytesIO(msg_bytes) - d: "Deferred[object]" = Deferred() - factory = ESMTPSenderFactory( - username, - password, - from_addr, - to_addr, - msg, - d, - heloFallback=True, - requireAuthentication=require_auth, - requireTransportSecurity=require_tls, - hostname=tls_hostname, - ) + def build_sender_factory(**kwargs) -> ESMTPSenderFactory: + return ESMTPSenderFactory( + username, + password, + from_addr, + to_addr, + msg, + d, + heloFallback=True, + requireAuthentication=require_auth, + requireTransportSecurity=require_tls, + **kwargs, + ) + + if _is_old_twisted: + # before twisted 21.2, we have to override the ESMTPSender protocol to disable + # TLS + factory = build_sender_factory() + + if not enable_tls: + factory.protocol = _NoTLSESMTPSender + else: + # for twisted 21.2 and later, there is a 'hostname' parameter which we should + # set to enable TLS. + factory = build_sender_factory(hostname=smtphost if enable_tls else None) # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] @@ -154,5 +183,5 @@ async def send_email( password=self._smtp_pass, require_auth=self._smtp_user is not None, require_tls=self._require_transport_security, - tls_hostname=self._smtp_host if self._enable_tls else None, + enable_tls=self._enable_tls, ) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py new file mode 100644 index 000000000000..6f77b1237c97 --- /dev/null +++ b/tests/handlers/test_send_email.py @@ -0,0 +1,112 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Tuple + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.address import IPv4Address +from twisted.internet.defer import ensureDeferred +from twisted.mail import interfaces, smtp + +from tests.server import FakeTransport +from tests.unittest import HomeserverTestCase + + +@implementer(interfaces.IMessageDelivery) +class _DummyMessageDelivery: + def __init__(self): + # (recipient, message) tuples + self.messages: List[Tuple[smtp.Address, bytes]] = [] + + def receivedHeader(self, helo, origin, recipients): + return None + + def validateFrom(self, helo, origin): + return origin + + def record_message(self, recipient: smtp.Address, message: bytes): + self.messages.append((recipient, message)) + + def validateTo(self, user: smtp.User): + return lambda: _DummyMessage(self, user) + + +@implementer(interfaces.IMessageSMTP) +class _DummyMessage: + """IMessageSMTP implementation which saves the message delivered to it + to the _DummyMessageDelivery object. + """ + + def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User): + self._delivery = delivery + self._user = user + self._buffer: List[bytes] = [] + + def lineReceived(self, line): + self._buffer.append(line) + + def eomReceived(self): + message = b"\n".join(self._buffer) + b"\n" + self._delivery.record_message(self._user.dest, message) + return defer.succeed(b"saved") + + def connectionLost(self): + pass + + +class SendEmailHandlerTestCase(HomeserverTestCase): + def test_send_email(self): + """Happy-path test that we can send email to a non-TLS server.""" + h = self.hs.get_send_email_handler() + d = ensureDeferred( + h.send_email( + "foo@bar.com", "test subject", "Tests", "HTML content", "Text content" + ) + ) + # there should be an attempt to connect to localhost:25 + self.assertEqual(len(self.reactor.tcpClients), 1) + (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ + 0 + ] + self.assertEqual(host, "localhost") + self.assertEqual(port, 25) + + # wire it up to an SMTP server + message_delivery = _DummyMessageDelivery() + server_protocol = smtp.ESMTP() + server_protocol.delivery = message_delivery + # make sure that the server uses the test reactor to set timeouts + server_protocol.callLater = self.reactor.callLater # type: ignore[assignment] + + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor)) + server_protocol.makeConnection( + FakeTransport( + client_protocol, + self.reactor, + peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + ) + ) + + # the message should now get delivered + self.get_success(d, by=0.1) + + # check it arrived + self.assertEqual(len(message_delivery.messages), 1) + user, msg = message_delivery.messages.pop() + self.assertEqual(str(user), "foo@bar.com") + self.assertIn(b"Subject: test subject", msg) diff --git a/tests/server.py b/tests/server.py index 6fddd3b30558..b861c7b866f8 100644 --- a/tests/server.py +++ b/tests/server.py @@ -10,9 +10,10 @@ from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier -from twisted.internet.defer import Deferred, fail, succeed +from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( + IAddress, IHostnameResolver, IProtocol, IPullProducer, @@ -511,6 +512,9 @@ class FakeTransport: will get called back for connectionLost() notifications etc. """ + _peer_address: Optional[IAddress] = attr.ib(default=None) + """The value to be returend by getPeer""" + disconnecting = False disconnected = False connected = True @@ -519,7 +523,7 @@ class FakeTransport: autoflush = attr.ib(default=True) def getPeer(self): - return None + return self._peer_address def getHost(self): return None @@ -572,7 +576,12 @@ def registerProducer(self, producer, streaming): self.producerStreaming = streaming def _produce(): - d = self.producer.resumeProducing() + if not self.producer: + # we've been unregistered + return + # some implementations of IProducer (for example, FileSender) + # don't return a deferred. + d = maybeDeferred(self.producer.resumeProducing) d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) if not streaming: