This repository has been archived by the owner on Apr 26, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
171 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,11 +31,12 @@ | |
Membership, | ||
RelationTypes, | ||
) | ||
from synapse.api.errors import Codes, HttpResponseException | ||
from synapse.api.errors import Code, Codes, HttpResponseException | ||
from synapse.handlers.pagination import PurgeStatus | ||
from synapse.rest import admin | ||
from synapse.rest.client import account, directory, login, profile, room, sync | ||
from synapse.server import HomeServer | ||
from synapse.spam_checker_api import ALLOW, Decision | ||
from synapse.types import JsonDict, RoomAlias, UserID, create_requester | ||
from synapse.util import Clock | ||
from synapse.util.stringutils import random_string | ||
|
@@ -676,9 +677,9 @@ def test_post_room_invitees_ratelimit(self) -> None: | |
channel = self.make_request("POST", "/createRoom", content) | ||
self.assertEqual(200, channel.code) | ||
|
||
def test_spam_checker_may_join_room(self) -> None: | ||
def test_spam_checker_may_join_room_old(self) -> None: | ||
"""Tests that the user_may_join_room spam checker callback is correctly bypassed | ||
when creating a new room. | ||
when creating a new room (old-style API, returning a boolean). | ||
""" | ||
|
||
async def user_may_join_room( | ||
|
@@ -700,6 +701,30 @@ async def user_may_join_room( | |
|
||
self.assertEqual(join_mock.call_count, 0) | ||
|
||
def test_spam_checker_may_join_room(self) -> None: | ||
"""Tests that the user_may_join_room spam checker callback is correctly bypassed | ||
when creating a new room. | ||
""" | ||
|
||
async def user_may_join_room( | ||
mxid: str, | ||
room_id: str, | ||
is_invite: bool, | ||
) -> Decision: | ||
return Code.FORBIDDEN | ||
|
||
join_mock = Mock(side_effect=user_may_join_room) | ||
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) | ||
|
||
channel = self.make_request( | ||
"POST", | ||
"/createRoom", | ||
{}, | ||
) | ||
self.assertEqual(channel.code, 200, channel.json_body) | ||
|
||
self.assertEqual(join_mock.call_count, 0) | ||
|
||
|
||
class RoomTopicTestCase(RoomBase): | ||
"""Tests /rooms/$room_id/topic REST events.""" | ||
|
@@ -910,9 +935,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) | ||
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) | ||
|
||
def test_spam_checker_may_join_room(self) -> None: | ||
def test_spam_checker_may_join_room_old(self) -> None: | ||
"""Tests that the user_may_join_room spam checker callback is correctly called | ||
and blocks room joins when needed. | ||
and blocks room joins when needed (old-style API, return a boolean). | ||
""" | ||
|
||
# Register a dummy callback. Make it allow all room joins for now. | ||
|
@@ -967,6 +992,63 @@ async def user_may_join_room( | |
return_value = False | ||
self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) | ||
|
||
def test_spam_checker_may_join_room(self) -> None: | ||
"""Tests that the user_may_join_room spam checker callback is correctly called | ||
and blocks room joins when needed. | ||
""" | ||
|
||
# Register a dummy callback. Make it allow all room joins for now. | ||
return_value: Decision = ALLOW | ||
|
||
async def user_may_join_room( | ||
userid: str, | ||
room_id: str, | ||
is_invited: bool, | ||
) -> Decision: | ||
return return_value | ||
|
||
callback_mock = Mock(side_effect=user_may_join_room) | ||
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) | ||
|
||
# Join a first room, without being invited to it. | ||
self.helper.join(self.room1, self.user2, tok=self.tok2) | ||
|
||
# Check that the callback was called with the right arguments. | ||
expected_call_args = ( | ||
( | ||
self.user2, | ||
self.room1, | ||
False, | ||
), | ||
) | ||
self.assertEqual( | ||
callback_mock.call_args, | ||
expected_call_args, | ||
callback_mock.call_args, | ||
) | ||
|
||
# Join a second room, this time with an invite for it. | ||
self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) | ||
self.helper.join(self.room2, self.user2, tok=self.tok2) | ||
|
||
# Check that the callback was called with the right arguments. | ||
expected_call_args = ( | ||
( | ||
self.user2, | ||
self.room2, | ||
True, | ||
), | ||
) | ||
self.assertEqual( | ||
callback_mock.call_args, | ||
expected_call_args, | ||
callback_mock.call_args, | ||
) | ||
|
||
# Now make the callback deny all room joins, and check that a join actually fails. | ||
return_value = Code.FORBIDDEN | ||
self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) | ||
|
||
|
||
class RoomJoinRatelimitTestCase(RoomBase): | ||
user_id = "@sid1:red" | ||
|
@@ -2586,7 +2668,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |
|
||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) | ||
|
||
def test_threepid_invite_spamcheck(self) -> None: | ||
def test_threepid_invite_spamcheck_old(self) -> None: | ||
# Mock a few functions to prevent the test from failing due to failing to talk to | ||
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we | ||
# can check its call_count later on during the test. | ||
|
@@ -2640,3 +2722,58 @@ def test_threepid_invite_spamcheck(self) -> None: | |
|
||
# Also check that it stopped before calling _make_and_store_3pid_invite. | ||
make_invite_mock.assert_called_once() | ||
|
||
def test_threepid_invite_spamcheck(self) -> None: | ||
# Mock a few functions to prevent the test from failing due to failing to talk to | ||
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we | ||
# can check its call_count later on during the test. | ||
make_invite_mock = Mock(return_value=make_awaitable(0)) | ||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock | ||
self.hs.get_identity_handler().lookup_3pid = Mock( | ||
return_value=make_awaitable(None), | ||
) | ||
|
||
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it | ||
# allow everything for now. | ||
mock = Mock(return_value=make_awaitable(ALLOW)) | ||
self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) | ||
|
||
# Send a 3PID invite into the room and check that it succeeded. | ||
email_to_invite = "[email protected]" | ||
channel = self.make_request( | ||
method="POST", | ||
path="/rooms/" + self.room_id + "/invite", | ||
content={ | ||
"id_server": "example.com", | ||
"id_access_token": "sometoken", | ||
"medium": "email", | ||
"address": email_to_invite, | ||
}, | ||
access_token=self.tok, | ||
) | ||
self.assertEqual(channel.code, 200) | ||
|
||
# Check that the callback was called with the right params. | ||
mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) | ||
|
||
# Check that the call to send the invite was made. | ||
make_invite_mock.assert_called_once() | ||
|
||
# Now change the return value of the callback to deny any invite and test that | ||
# we can't send the invite. | ||
mock.return_value = make_awaitable(Code.FORBIDDEN) | ||
channel = self.make_request( | ||
method="POST", | ||
path="/rooms/" + self.room_id + "/invite", | ||
content={ | ||
"id_server": "example.com", | ||
"id_access_token": "sometoken", | ||
"medium": "email", | ||
"address": email_to_invite, | ||
}, | ||
access_token=self.tok, | ||
) | ||
self.assertEqual(channel.code, 403) | ||
|
||
# Also check that it stopped before calling _make_and_store_3pid_invite. | ||
make_invite_mock.assert_called_once() |