Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow Requests to be sent to exempt_when #160

Merged
merged 13 commits into from
Jun 27, 2024
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Change Log

## [0.1.10] - 2024-06-04

### Changed

- Breaking change: allow usage of the request object in the except_when function (thanks @colin99d)

## [0.1.9] - 2024-02-05

### Added
Expand Down
13 changes: 7 additions & 6 deletions slowapi/extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
The starlette extension to rate-limit requests
"""

import asyncio
import functools
import inspect
Expand Down Expand Up @@ -486,7 +487,7 @@ def __evaluate_limits(
limit_for_header = None
for lim in limits:
limit_scope = lim.scope or endpoint
if lim.is_exempt:
if lim.is_exempt(request):
continue
if lim.methods is not None and request.method.lower() not in lim.methods:
continue
Expand Down Expand Up @@ -703,11 +704,9 @@ def decorator(func: Callable[..., Response]):
else:
self._route_limits.setdefault(name, []).extend(static_limits)

connection_type: Optional[str] = None
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
connection_type = parameter.name
break
else:
raise Exception(
Expand Down Expand Up @@ -736,7 +735,8 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
if not isinstance(response, Response):
# get the response object from the decorated endpoint function
self._inject_headers(
kwargs.get("response"), request.state.view_rate_limit # type: ignore
kwargs.get("response"), # type: ignore
request.state.view_rate_limit,
)
else:
self._inject_headers(
Expand Down Expand Up @@ -768,7 +768,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
if not isinstance(response, Response):
# get the response object from the decorated endpoint function
self._inject_headers(
kwargs.get("response"), request.state.view_rate_limit # type: ignore
kwargs.get("response"),
request.state.view_rate_limit, # type: ignore
)
else:
self._inject_headers(
Expand Down Expand Up @@ -805,7 +806,7 @@ def limit(
* **error_message**: string (or callable that returns one) to override the
error message used in the response.
* **exempt_when**: function returning a boolean indicating whether to exempt
the route from the limit
the route from the limit. This function can optionally use a Request object.
* **cost**: integer (or callable that returns one) which is the cost of a hit
* **override_defaults**: whether to override the default limits (default: True)
"""
Expand Down
18 changes: 15 additions & 3 deletions slowapi/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Iterator, List, Optional, Union

from limits import RateLimitItem, parse_many # type: ignore
from starlette.requests import Request


class Limit(object):
Expand All @@ -28,16 +29,27 @@ def __init__(
self.methods = methods
self.error_message = error_message
self.exempt_when = exempt_when
self._exempt_when_takes_request = (
self.exempt_when
and len(inspect.signature(self.exempt_when).parameters) == 1
)
self.cost = cost
laurentS marked this conversation as resolved.
Show resolved Hide resolved
self.override_defaults = override_defaults

@property
def is_exempt(self) -> bool:
def is_exempt(self, request: Optional[Request] = None) -> bool:
"""
Check if the limit is exempt.

** parameter **
* **request**: the request object

Return True to exempt the route from the limit.
"""
return self.exempt_when() if self.exempt_when is not None else False
if self.exempt_when is None:
return False
if self._exempt_when_takes_request and request:
return self.exempt_when(request)
return self.exempt_when()

@property
def scope(self) -> str:
Expand Down
55 changes: 55 additions & 0 deletions tests/test_starlette_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,61 @@ def t1(request: Request):
if i < 5:
assert response.text == "test"

def test_exempt_when_argument(self, build_starlette_app):
app, limiter = build_starlette_app(key_func=get_ipaddr)

def return_true():
return True

def return_false():
return False

def dynamic(request: Request):
user_agent = request.headers.get("User-Agent")
if user_agent is None:
return False
return user_agent == "exempt"

@limiter.limit("1/minute", exempt_when=return_true)
def always_true(request: Request):
return PlainTextResponse("test")

@limiter.limit("1/minute", exempt_when=return_false)
def always_false(request: Request):
return PlainTextResponse("test")

@limiter.limit("1/minute", exempt_when=dynamic)
def always_dynamic(request: Request):
return PlainTextResponse("test")

app.add_route("/true", always_true)
app.add_route("/false", always_false)
app.add_route("/dynamic", always_dynamic)

client = TestClient(app)
# Test always true always exempting
for i in range(0, 2):
response = client.get("/true")
assert response.status_code == 200
assert response.text == "test"
# Test always false hitting the limit after one hit
for i in range(0, 2):
response = client.get("/false")
assert response.status_code == 200 if i < 1 else 429
if i < 1:
assert response.text == "test"
# Test dynamic not exempting with the correct header
for i in range(0, 2):
response = client.get("/dynamic", headers={"User-Agent": "exempt"})
assert response.status_code == 200
assert response.text == "test"
# Test dynamic exempting with the incorrect header
for i in range(0, 2):
response = client.get("/dynamic")
assert response.status_code == 200 if i < 1 else 429
if i < 1:
assert response.text == "test"

def test_shared_decorator(self, build_starlette_app):
app, limiter = build_starlette_app(key_func=get_ipaddr)

Expand Down
Loading