From e257f061a41ee45b25398eba2cead559b0c255ac Mon Sep 17 00:00:00 2001 From: colin99d Date: Tue, 11 Apr 2023 11:32:33 -0400 Subject: [PATCH 01/10] Sending the request --- slowapi/extension.py | 4 +--- slowapi/wrappers.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/slowapi/extension.py b/slowapi/extension.py index 811d577..d31abe3 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -482,7 +482,7 @@ def __evaluate_limits( limit_for_header = None for lim in limits: limit_scope = lim.scope or endpoint - if lim.is_exempt: + if lim.is_exempt(request): continue if lim.methods is not None and request.method.lower() not in lim.methods: continue @@ -699,11 +699,9 @@ def decorator(func: Callable[..., Response]): else: self._route_limits.setdefault(name, []).extend(static_limits) - connection_type: Optional[str] = None sig = inspect.signature(func) for idx, parameter in enumerate(sig.parameters.values()): if parameter.name == "request" or parameter.name == "websocket": - connection_type = parameter.name break else: raise Exception( diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index a1741c5..54b0a31 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -2,6 +2,7 @@ from typing import Callable, Iterator, List, Optional, Union from limits import RateLimitItem, parse_many # type: ignore +from starlette.requests import Request class Limit(object): @@ -31,13 +32,18 @@ def __init__( self.cost = cost self.override_defaults = override_defaults - @property - def is_exempt(self) -> bool: + def is_exempt(self, request: Request) -> bool: """ Check if the limit is exempt. Return True to exempt the route from the limit. """ - return self.exempt_when() if self.exempt_when is not None else False + if self.exempt_when is None: + return False + params = inspect.signature(self.exempt_when).parameters + param_len = len(params) + if param_len == 1: + return self.exempt_when(request) + return self.exempt_when() @property def scope(self) -> str: From d5eb0b4ba007ffc8fe9f6beb2fda45fce99e1946 Mon Sep 17 00:00:00 2001 From: colin99d Date: Tue, 11 Apr 2023 16:09:42 -0400 Subject: [PATCH 02/10] Added rate limiting --- slowapi/wrappers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index 54b0a31..15b908b 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -35,6 +35,10 @@ def __init__( def is_exempt(self, request: Request) -> bool: """ Check if the limit is exempt. + + ** parameter ** + * **request**: the request object + Return True to exempt the route from the limit. """ if self.exempt_when is None: From 1167752db7945b2f5bded5f89b3272910685345d Mon Sep 17 00:00:00 2001 From: Colin Delahunty <72827203+colin99d@users.noreply.github.com> Date: Wed, 12 Jul 2023 10:17:52 -0400 Subject: [PATCH 03/10] Update wrappers.py --- slowapi/wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index 15b908b..c5f2972 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -32,7 +32,7 @@ def __init__( self.cost = cost self.override_defaults = override_defaults - def is_exempt(self, request: Request) -> bool: + def is_exempt(self, request: Optional[Request] = None) -> bool: """ Check if the limit is exempt. @@ -45,7 +45,7 @@ def is_exempt(self, request: Request) -> bool: return False params = inspect.signature(self.exempt_when).parameters param_len = len(params) - if param_len == 1: + if param_len == 1 and request: return self.exempt_when(request) return self.exempt_when() From 93046a00c9a00ab11bfece920f37795004122171 Mon Sep 17 00:00:00 2001 From: Laurent Savaete Date: Mon, 21 Aug 2023 16:44:19 +0200 Subject: [PATCH 04/10] Apply suggestions from code review Co-authored-by: Reuben Thomas-Davis --- slowapi/wrappers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index c5f2972..4c39744 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -29,6 +29,7 @@ def __init__( self.methods = methods self.error_message = error_message self.exempt_when = exempt_when + self._exempt_when_takes_request = len(inspect.signature(self.exempt_when).parameters) == 1 self.cost = cost self.override_defaults = override_defaults @@ -43,9 +44,7 @@ def is_exempt(self, request: Optional[Request] = None) -> bool: """ if self.exempt_when is None: return False - params = inspect.signature(self.exempt_when).parameters - param_len = len(params) - if param_len == 1 and request: + if self._exempt_when_takes_request and request: return self.exempt_when(request) return self.exempt_when() From fa53f193b0bb334e30cd3d46974bd47becc98cf6 Mon Sep 17 00:00:00 2001 From: colin99d Date: Thu, 5 Oct 2023 08:35:00 -0400 Subject: [PATCH 05/10] Black formatting --- slowapi/wrappers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index 4c39744..a149f9b 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -29,7 +29,9 @@ def __init__( self.methods = methods self.error_message = error_message self.exempt_when = exempt_when - self._exempt_when_takes_request = len(inspect.signature(self.exempt_when).parameters) == 1 + self._exempt_when_takes_request = ( + len(inspect.signature(self.exempt_when).parameters) == 1 + ) self.cost = cost self.override_defaults = override_defaults From acc5c879b99c858882d1a48f3fa000c8da3a99db Mon Sep 17 00:00:00 2001 From: colin99d Date: Fri, 5 Jan 2024 12:11:04 -0500 Subject: [PATCH 06/10] Mypy fix --- slowapi/wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index a149f9b..d5677d4 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -30,7 +30,8 @@ def __init__( self.error_message = error_message self.exempt_when = exempt_when self._exempt_when_takes_request = ( - len(inspect.signature(self.exempt_when).parameters) == 1 + self.exempt_when + and len(inspect.signature(self.exempt_when).parameters) == 1 ) self.cost = cost self.override_defaults = override_defaults From 3488ec5c3c72f13f1fdf9497317a336cb8c10496 Mon Sep 17 00:00:00 2001 From: Colin Delahunty <72827203+colin99d@users.noreply.github.com> Date: Tue, 4 Jun 2024 19:12:03 +0000 Subject: [PATCH 07/10] Fixed --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc3d761..eaaf869 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## [0.1.10] - 2024-06-04 + +### Changed + +- Breaking: allow usage of the request object in the except_when function (thanks @colin99d) + ## [0.1.9] - 2024-02-05 ### Added From c1681754cb9c7fc0a6365174af4e5bac75354b58 Mon Sep 17 00:00:00 2001 From: Colin Delahunty <72827203+colin99d@users.noreply.github.com> Date: Tue, 4 Jun 2024 20:13:16 +0100 Subject: [PATCH 08/10] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eaaf869..e650180 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### Changed -- Breaking: allow usage of the request object in the except_when function (thanks @colin99d) +- Breaking change: allow usage of the request object in the except_when function (thanks @colin99d) ## [0.1.9] - 2024-02-05 From 57223b72d764f54dc5d394a67c76ad87ed8ae75b Mon Sep 17 00:00:00 2001 From: Colin Delahunty Date: Thu, 27 Jun 2024 12:47:12 -0400 Subject: [PATCH 09/10] Improved logs, added a test --- slowapi/extension.py | 9 +++-- tests/test_starlette_extension.py | 55 +++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/slowapi/extension.py b/slowapi/extension.py index e35d3cc..488f0d9 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -1,6 +1,7 @@ """ The starlette extension to rate-limit requests """ + import asyncio import functools import inspect @@ -734,7 +735,8 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response: if not isinstance(response, Response): # get the response object from the decorated endpoint function self._inject_headers( - kwargs.get("response"), request.state.view_rate_limit # type: ignore + kwargs.get("response"), + request.state.view_rate_limit, # type: ignore ) else: self._inject_headers( @@ -766,7 +768,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Response: if not isinstance(response, Response): # get the response object from the decorated endpoint function self._inject_headers( - kwargs.get("response"), request.state.view_rate_limit # type: ignore + kwargs.get("response"), + request.state.view_rate_limit, # type: ignore ) else: self._inject_headers( @@ -803,7 +806,7 @@ def limit( * **error_message**: string (or callable that returns one) to override the error message used in the response. * **exempt_when**: function returning a boolean indicating whether to exempt - the route from the limit + the route from the limit. This function can optionally use a Request object. * **cost**: integer (or callable that returns one) which is the cost of a hit * **override_defaults**: whether to override the default limits (default: True) """ diff --git a/tests/test_starlette_extension.py b/tests/test_starlette_extension.py index 7f21c1d..0e26baa 100644 --- a/tests/test_starlette_extension.py +++ b/tests/test_starlette_extension.py @@ -43,6 +43,61 @@ def t1(request: Request): if i < 5: assert response.text == "test" + def test_exempt_when_argument(self, build_starlette_app): + app, limiter = build_starlette_app(key_func=get_ipaddr) + + def return_true(): + return True + + def return_false(): + return False + + def dynamic(request: Request): + user_agent = request.headers.get("User-Agent") + if user_agent is None: + return False + return user_agent == "exempt" + + @limiter.limit("1/minute", exempt_when=return_true) + def always_true(request: Request): + return PlainTextResponse("test") + + @limiter.limit("1/minute", exempt_when=return_false) + def always_false(request: Request): + return PlainTextResponse("test") + + @limiter.limit("1/minute", exempt_when=dynamic) + def always_dynamic(request: Request): + return PlainTextResponse("test") + + app.add_route("/true", always_true) + app.add_route("/false", always_false) + app.add_route("/dynamic", always_dynamic) + + client = TestClient(app) + # Test always true always exempting + for i in range(0, 2): + response = client.get("/true") + assert response.status_code == 200 + assert response.text == "test" + # Test always false hitting the limit after one hit + for i in range(0, 2): + response = client.get("/false") + assert response.status_code == 200 if i < 1 else 429 + if i < 1: + assert response.text == "test" + # Test dynamic not exempting with the correct header + for i in range(0, 2): + response = client.get("/dynamic", headers={"User-Agent": "exempt"}) + assert response.status_code == 200 + assert response.text == "test" + # Test dynamic exempting with the incorrect header + for i in range(0, 2): + response = client.get("/dynamic") + assert response.status_code == 200 if i < 1 else 429 + if i < 1: + assert response.text == "test" + def test_shared_decorator(self, build_starlette_app): app, limiter = build_starlette_app(key_func=get_ipaddr) From 91145c087d658b6c3404935cd02155e291c2682b Mon Sep 17 00:00:00 2001 From: Colin Delahunty Date: Thu, 27 Jun 2024 13:06:02 -0400 Subject: [PATCH 10/10] Fixed mypy error --- slowapi/extension.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/slowapi/extension.py b/slowapi/extension.py index 488f0d9..344a16f 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -735,8 +735,8 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response: if not isinstance(response, Response): # get the response object from the decorated endpoint function self._inject_headers( - kwargs.get("response"), - request.state.view_rate_limit, # type: ignore + kwargs.get("response"), # type: ignore + request.state.view_rate_limit, ) else: self._inject_headers(