Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Ratelimit 3PID /requestToken API #9238

Merged
merged 6 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9238.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ratelimited to 3PID `/requestToken` API.
6 changes: 5 additions & 1 deletion docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
#
# The defaults are as shown below.
#
Expand Down Expand Up @@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# remote:
# per_second: 0.01
# burst_count: 3

#
#rc_3pid_validation:
# per_second: 0.003
# burst_count: 5

# Ratelimiting settings for incoming federation
#
Expand Down
2 changes: 1 addition & 1 deletion synapse/config/_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class RootConfig:
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
ratelimit: ratelimiting.RatelimitConfig
ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
Expand Down
13 changes: 11 additions & 2 deletions synapse/config/ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"])
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because now mypy hates me. I spent a while trying to figure out how to make the type hints in synapse/app/ratelimiting.py accept floats or ints for burst count and I almost cried.

Anyway, it doesn't really make sense for burst count to be a float so 🤷

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised there isn't a Number = Union[int, float] 🤷 We could add that to synapse.types if it was useful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend not looking at numbers.Number



class FederationRateLimitConfig:
Expand Down Expand Up @@ -102,6 +102,11 @@ def read_config(self, config, **kwargs):
defaults={"per_second": 0.01, "burst_count": 3},
)

self.rc_3pid_validation = RateLimitConfig(
config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5},
)

def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
Expand Down Expand Up @@ -131,6 +136,7 @@ def generate_config_section(self, **kwargs):
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
#
# The defaults are as shown below.
#
Expand Down Expand Up @@ -164,7 +170,10 @@ def generate_config_section(self, **kwargs):
# remote:
# per_second: 0.01
# burst_count: 3

#
#rc_3pid_validation:
# per_second: 0.003
# burst_count: 5

# Ratelimiting settings for incoming federation
#
Expand Down
28 changes: 28 additions & 0 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
HttpResponseException,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
Expand Down Expand Up @@ -57,6 +59,32 @@ def __init__(self, hs):

self._web_client_location = hs.config.invite_client_location

# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)

def ratelimit_request_token_requests(
self, request: SynapseRequest, medium: str, address: str,
):
"""Used to ratelimit requests to `/requestToken` by IP and address.
clokep marked this conversation as resolved.
Show resolved Hide resolved

Args:
request: The associated request
medium: The type of threepid, e.g. "msisdn" or "email"
address: The actual threepid ID, e.g. the phone number or email address
"""

self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))

async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
Expand Down
12 changes: 10 additions & 2 deletions synapse/rest/client/v2_alpha/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
Expand Down Expand Up @@ -103,6 +103,8 @@ async def on_POST(self, request):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)

self.identity_handler.ratelimit_request_token_requests(request, "email", email)

# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
Expand Down Expand Up @@ -379,6 +381,8 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(request, "email", email)

if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
Expand Down Expand Up @@ -430,7 +434,7 @@ async def on_POST(self, request):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
Expand Down Expand Up @@ -458,6 +462,10 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)

if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
Expand Down
6 changes: 6 additions & 0 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(request, "email", email)

existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
Expand Down Expand Up @@ -205,6 +207,10 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)

existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
Expand Down
90 changes: 84 additions & 6 deletions tests/rest/client/v2_alpha/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes
from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
Expand Down Expand Up @@ -112,6 +112,56 @@ def test_basic_password_reset(self):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)

@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self):
"""Test that we ratelimit /requestToken for the same email.
"""
old_password = "monkey"
new_password = "kangeroo"

user_id = self.register_user("kermit", old_password)
self.login("kermit", old_password)

email = "[email protected]"

# Add a threepid
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
medium="email",
address=email,
validated_at=0,
added_at=0,
)
)

def reset(ip):
client_secret = "foobar"
session_id = self._request_token(email, client_secret, ip)

self.assertEquals(len(self.email_attempts), 1)
link = self._get_link_from_email()

self._validate_token(link)

self._reset_password(new_password, session_id, client_secret)

self.email_attempts.clear()

# We expect to be able to make three requests before getting rate
# limited.
#
# We change IPs to ensure that we're not being ratelimited due to the
# same IP
reset("127.0.0.1")
reset("127.0.0.2")
reset("127.0.0.3")

with self.assertRaises(HttpResponseException) as cm:
reset("127.0.0.4")

self.assertEqual(cm.exception.code, 429)

def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
Expand Down Expand Up @@ -239,13 +289,18 @@ def test_password_reset_bad_email_inhibit_error(self):

self.assertIsNotNone(session_id)

def _request_token(self, email, client_secret):
def _request_token(self, email, client_secret, ip="127.0.0.1"):
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
client_ip=ip,
)
self.assertEquals(200, channel.code, channel.result)

if channel.code != 200:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"],
)

return channel.json_body["sid"]

Expand Down Expand Up @@ -509,6 +564,21 @@ def test_add_email_address_casefold(self):
def test_address_trim(self):
self.get_success(self._add_email(" [email protected] ", "[email protected]"))

@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_ip(self):
"""Tests that adding emails is ratelimited by IP
"""

# We expect to be able to set three emails before getting ratelimited.
self.get_success(self._add_email("[email protected]", "[email protected]"))
self.get_success(self._add_email("[email protected]", "[email protected]"))
self.get_success(self._add_email("[email protected]", "[email protected]"))

with self.assertRaises(HttpResponseException) as cm:
self.get_success(self._add_email("[email protected]", "[email protected]"))

self.assertEqual(cm.exception.code, 429)

def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
"""
Expand Down Expand Up @@ -777,7 +847,11 @@ def _request_token(
body["next_link"] = next_link

channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
self.assertEquals(expect_code, channel.code, channel.result)

if channel.code != expect_code:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"],
)

return channel.json_body.get("sid")

Expand Down Expand Up @@ -823,10 +897,12 @@ def _get_link_from_email(self):
def _add_email(self, request_email, expected_email):
"""Test adding an email to profile
"""
previous_email_attempts = len(self.email_attempts)

client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)

self.assertEquals(len(self.email_attempts), 1)
self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()

self._validate_token(link)
Expand Down Expand Up @@ -855,4 +931,6 @@ def _add_email(self, request_email, expected_email):

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])

threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
self.assertIn(expected_email, threepids)
9 changes: 7 additions & 2 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
_producer = None

@property
Expand Down Expand Up @@ -120,7 +121,7 @@ def requestDone(self, _self):
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", "127.0.0.1", 3423)
return address.IPv4Address("TCP", self._ip, 3423)

def getHost(self):
return None
Expand Down Expand Up @@ -196,6 +197,7 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Make a web request using the given method, path and content, and render it
Expand Down Expand Up @@ -223,6 +225,9 @@ def make_request(
will pump the reactor until the the renderer tells the channel the request
is finished.

client_ip: The IP to use as the requesting IP. Useful for testing
ratelimiting.

Returns:
channel
"""
Expand Down Expand Up @@ -250,7 +255,7 @@ def make_request(
if isinstance(content, str):
content = content.encode("utf8")

channel = FakeChannel(site, reactor)
channel = FakeChannel(site, reactor, ip=client_ip)

req = request(channel)
req.content = BytesIO(content)
Expand Down
5 changes: 5 additions & 0 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
Expand All @@ -410,6 +411,9 @@ def make_request(

custom_headers: (name, value) pairs to add as request headers

client_ip: The IP to use as the requesting IP. Useful for testing
ratelimiting.

Returns:
The FakeChannel object which stores the result of the request.
"""
Expand All @@ -426,6 +430,7 @@ def make_request(
content_is_form,
await_result,
custom_headers,
client_ip,
)

def setup_test_homeserver(self, *args, **kwargs):
Expand Down
Loading