Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Allow rate limiters to passively record actions they cannot limit #13253

Merged
merged 6 commits into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13253.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparatory work for a per-room rate limiter on joins.
94 changes: 82 additions & 12 deletions synapse/api/ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,33 @@ class Ratelimiter:
"""
Ratelimit actions marked by arbitrary keys.

(Note that the source code speaks of "actions" and "burst_count" rather than
"tokens" and a "bucket_size".)

This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
permitted requests for that key. Each bucket starts empty, and gradually leaks
tokens at a rate of `rate_hz`.

Upon an incoming request, we must determine:
- the key that this request falls under (which bucket to inspect), and
- the cost C of this request in tokens.
Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
the request is permitted and `cost` tokens are added to the bucket.
Otherwise the request is denied, and the bucket continues to hold T tokens.

This means that the limiter enforces an average request frequency of `rate_hz`,
while accumulating a buffer of up to `burst_count` requests which can be consumed
instantaneously.

The tricky bit is the leaking. We do not want to have a periodic process which
leaks every bucket! Instead, we track
- the time point when the bucket was last completely empty, and
- how many tokens have added to the bucket permitted since then.
Then for each incoming request, we can calculate how many tokens have leaked
since this time point, and use that to decide if we should accept or reject the
request.

Args:
clock: A homeserver clock, for retrieving the current time
rate_hz: The long term number of actions that can be performed in a second.
Expand All @@ -41,14 +68,30 @@ def __init__(
self.burst_count = burst_count
self.store = store

# A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
# to a tuple representing:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
# An ordered dictionary representing the token buckets tracked by this rate
# limiter. Each entry maps a key of arbitrary type to a tuple representing:
# * The number of tokens currently in the bucket,
# * The time point when the bucket was last completely empty, and
# * The rate_hz (leak rate) of this particular bucket.
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()

def _get_key(
self, requester: Optional[Requester], key: Optional[Hashable]
) -> Hashable:
"""Use the requester's MXID as a fallback key if no key is provided."""
if key is None:
if not requester:
raise ValueError("Must supply at least one of `requester` or `key`")

key = requester.user.to_string()
return key

def _get_action_counts(
self, key: Hashable, time_now_s: float
) -> Tuple[float, float, float]:
"""Retrieve the action counts, with a fallback representing an empty bucket."""
return self.actions.get(key, (0.0, time_now_s, 0.0))

async def can_do_action(
self,
requester: Optional[Requester],
Expand Down Expand Up @@ -88,11 +131,7 @@ async def can_do_action(
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
if key is None:
if not requester:
raise ValueError("Must supply at least one of `requester` or `key`")

key = requester.user.to_string()
key = self._get_key(requester, key)

if requester:
# Disable rate limiting of users belonging to any AS that is configured
Expand Down Expand Up @@ -121,7 +160,7 @@ async def can_do_action(
self._prune_message_counts(time_now_s)

# Check if there is an existing count entry for this key
action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
action_count, time_start, _ = self._get_action_counts(key, time_now_s)

# Check whether performing another action is allowed
time_delta = time_now_s - time_start
Expand Down Expand Up @@ -164,6 +203,37 @@ async def can_do_action(

return allowed, time_allowed

def record_action(
self,
requester: Optional[Requester],
key: Optional[Hashable] = None,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
) -> None:
"""Record that an action(s) took place, even if they violate the rate limit.

This is useful for tracking the frequency of events that happen across
federation which we still want to impose local rate limits on. For instance, if
we are alice.com monitoring a particular room, we cannot prevent bob.com
from joining users to that room. However, we can track the number of recent
joins in the room and refuse to serve new joins ourselves if there have been too
many in the room across both homeservers.

Args:
requester: The requester that is doing the action, if any.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
n_actions: The number of times the user wants to do this action. If the user
cannot do all of the actions, the user's action count is not incremented
at all.
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
"""
key = self._get_key(requester, key)
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
self.actions[key] = (action_count + n_actions, time_start, rate_hz)

def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit
Expand Down
74 changes: 74 additions & 0 deletions tests/api/test_ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,77 @@ def consume_at(time: float) -> bool:

# Check that we get rate limited after using that token.
self.assertFalse(consume_at(11.1))

def test_record_action_which_doesnt_fill_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)

# Observe two actions, leaving room in the bucket for one more.
limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)

# We should be able to take a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
Comment on lines +327 to +329
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm struggling a little to understand how this differs from get_success(...) but it seems consistent with the other tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They look remarkably similar. AFAICS the only nontrivial difference is:

  • get_success_or_raise raises an exception with Failure.raiseException() if the deferred contains a Failure.
  • get_success (via successResultOf) will immediately fail the test.

Assuming that's correct, the two functions are essentially the same unless you're within a try block.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I figured it was some nuance like that but couldn't grok it immediately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely wasn't obvious to me!

self.assertTrue(success)

# ... but not two.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)

def test_record_action_which_fills_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)

# Observe three actions, filling up the bucket.
limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)

# We should be unable to take a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)

# If we wait 10 seconds to leak a token, we should be able to take one action...
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertTrue(success)

# ... but not two.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertFalse(success)

def test_record_action_which_overfills_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)

# Observe four actions, exceeding the bucket.
limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)

# We should be prevented from taking a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)

# If we wait 10 seconds to leak a token, we should be unable to take an action
# because the bucket is still full.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertFalse(success)

# But after another 10 seconds we leak a second token, giving us room for
# action.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
)
self.assertTrue(success)