Skip to content

Commit

Permalink
Add CustomScopedRateThrolle throttling class
Browse files Browse the repository at this point in the history
This will throttle anonymous enketo users more fairly
  • Loading branch information
FrankApiyo committed Aug 30, 2024
1 parent 000b72d commit a00ed0a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
27 changes: 23 additions & 4 deletions onadata/libs/tests/test_throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,32 @@

from rest_framework.test import APIRequestFactory

from onadata.libs.throttle import RequestHeaderThrottle
from onadata.libs.throttle import RequestHeaderThrottle, CustomScopedRateThrottle

class CustomScopedRateThrottleTest(TestCase):
def setUp(self):
"""
Reset the cache so that no throttles will be active
"""
cache.clear()
self.factory = APIRequestFactory()
self.throttle = CustomScopedRateThrottle()

def test_anonymous_users_get_throttled_based_on_uri_path(self):
request = self.factory.get("/enketo/1234/submission")
self.throttle.scope = "submission"
cache_key = self.throttle.get_cache_key(request, None)
self.assertEqual(cache_key, "submission_/enketo/1234/submission")

def test_users_get_throttled_based_on_uri_path(self):
request = self.factory.get("/enketo/1234/submission")
request.user = self.user
self.throttle.scope = "submission"
cache_key = self.throttle.get_cache_key(request, None)
self.assertEqual(cache_key, f"submission_{self.user.id}")


class ThrottlingTests(TestCase):
"""
Test Renderer class.
"""

def setUp(self):
"""
Expand Down
15 changes: 14 additions & 1 deletion onadata/libs/throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,20 @@

from django.conf import settings

from rest_framework.throttling import SimpleRateThrottle
from rest_framework.throttling import SimpleRateThrottle, ScopedRateThrottle


class CustomScopedRateThrottle(ScopedRateThrottle):
"""
Custom throttling for fair throttling for anonymous users sharing IP
"""

def get_cache_key(self, request, view):
if request.user and request.user.is_authenticated:
return super().get_cache_key(request, view)

return self.cache_format % \
{ 'scope': self.scope,'ident': request.path }


class RequestHeaderThrottle(SimpleRateThrottle):
Expand Down

0 comments on commit a00ed0a

Please sign in to comment.