Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override throttle rate for each endpoint #993

Merged
merged 3 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Version v30.2.2
- We enabled API throttling for a basic user and for a staff user
they can have unlimited access on API.

- We added throttle rate for each API endpoint and it can be
configured from the settings #991 https://github.com/nexB/vulnerablecode/issues/991.


Version v30.2.1
----------------
Expand Down
27 changes: 26 additions & 1 deletion vulnerabilities/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
from vulnerabilities.models import VulnerabilityReference
from vulnerabilities.models import VulnerabilitySeverity
from vulnerabilities.models import get_purl_query_lookups
from vulnerabilities.throttling import AliasesAPIThrottle
from vulnerabilities.throttling import BulkSearchCPEAPIThrottle
from vulnerabilities.throttling import BulkSearchPackagesAPIThrottle
from vulnerabilities.throttling import CPEAPIThrottle
from vulnerabilities.throttling import PackagesAPIThrottle
from vulnerabilities.throttling import VulnerabilitiesAPIThrottle
from vulnerabilities.throttling import VulnerablePackagesAPIThrottle


class VulnerabilitySeveritySerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -221,6 +228,15 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = PackageFilterSet

def get_throttles(self):
if self.action == "bulk_search":
throttle_classes = [BulkSearchPackagesAPIThrottle]
elif self.action == "all":
throttle_classes = [VulnerablePackagesAPIThrottle]
else:
throttle_classes = [PackagesAPIThrottle]
return [throttle() for throttle in throttle_classes]

# TODO: Fix the swagger documentation for this endpoint
@action(detail=False, methods=["post"])
def bulk_search(self, request):
Expand All @@ -246,7 +262,7 @@ def bulk_search(self, request):
if purl_data:
purl_response = PackageSerializer(purl_data[0], context={"request": request}).data
else:
purl_response = purl
purl_response = purl.to_dict()
purl_response["unresolved_vulnerabilities"] = []
purl_response["resolved_vulnerabilities"] = []
purl_response["purl"] = purl_string
Expand Down Expand Up @@ -302,6 +318,7 @@ def get_queryset(self):
serializer_class = VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = VulnerabilityFilterSet
throttle_classes = [VulnerabilitiesAPIThrottle]


class CPEFilterSet(filters.FilterSet):
Expand All @@ -320,6 +337,13 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet):
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = CPEFilterSet

def get_throttles(self):
if self.action == "bulk_search":
throttle_classes = [BulkSearchCPEAPIThrottle]
else:
throttle_classes = [CPEAPIThrottle]
return [throttle() for throttle in throttle_classes]

@action(detail=False, methods=["post"])
def bulk_search(self, request):
"""
Expand Down Expand Up @@ -357,3 +381,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = AliasFilterSet
throttle_classes = [AliasesAPIThrottle]
128 changes: 125 additions & 3 deletions vulnerabilities/tests/test_throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# See https://aboutcode.org for more information about nexB OSS projects.
#

import json

from django.contrib.auth import get_user_model
from rest_framework.test import APIClient
from rest_framework.test import APITestCase
Expand All @@ -30,10 +32,10 @@ def setUp(self):
self.staff_csrf_client = APIClient(enforce_csrf_checks=True)
self.staff_csrf_client.credentials(HTTP_AUTHORIZATION=self.staff_auth)

def test_api_throttling(self):
def test_packages_endpoint_throttling(self):

# A basic user can only access API 5 times a day
for i in range(0, 5):
# A basic user can only access /packages endpoint 10 times a day
for i in range(0, 10):
response = self.csrf_client.get("/api/packages")
self.assertEqual(response.status_code, 200)
response = self.staff_csrf_client.get("/api/packages")
Expand All @@ -46,3 +48,123 @@ def test_api_throttling(self):
response = self.staff_csrf_client.get("/api/packages", format="json")
# 200 - staff user can access API unlimited times
self.assertEqual(response.status_code, 200)

def test_cpes_endpoint_throttling(self):

# A basic user can only access /cpes endpoint 4 times a day
for i in range(0, 4):
response = self.csrf_client.get("/api/cpes")
self.assertEqual(response.status_code, 200)
response = self.staff_csrf_client.get("/api/cpes")
self.assertEqual(response.status_code, 200)

response = self.csrf_client.get("/api/cpes")
# 429 - too many requests for basic user
self.assertEqual(response.status_code, 429)

response = self.staff_csrf_client.get("/api/cpes", format="json")
# 200 - staff user can access API unlimited times
self.assertEqual(response.status_code, 200)

def test_all_vulnerable_packages_endpoint_throttling(self):

# A basic user can only access /packages/all 1 time a day
for i in range(0, 1):
response = self.csrf_client.get("/api/packages/all")
self.assertEqual(response.status_code, 200)
response = self.staff_csrf_client.get("/api/packages/all")
self.assertEqual(response.status_code, 200)

response = self.csrf_client.get("/api/packages/all")
# 429 - too many requests for basic user
self.assertEqual(response.status_code, 429)

response = self.staff_csrf_client.get("/api/packages/all", format="json")
# 200 - staff user can access API unlimited times
self.assertEqual(response.status_code, 200)

def test_vulnerabilities_endpoint_throttling(self):

# A basic user can only access /vulnerabilities 8 times a day
for i in range(0, 8):
response = self.csrf_client.get("/api/vulnerabilities")
self.assertEqual(response.status_code, 200)
response = self.staff_csrf_client.get("/api/vulnerabilities")
self.assertEqual(response.status_code, 200)

response = self.csrf_client.get("/api/vulnerabilities")
# 429 - too many requests for basic user
self.assertEqual(response.status_code, 429)

response = self.staff_csrf_client.get("/api/vulnerabilities", format="json")
# 200 - staff user can access API unlimited times
self.assertEqual(response.status_code, 200)

def test_aliases_endpoint_throttling(self):

# A basic user can only access /alias 2 times a day
for i in range(0, 2):
response = self.csrf_client.get("/api/alias")
self.assertEqual(response.status_code, 200)
response = self.staff_csrf_client.get("/api/alias")
self.assertEqual(response.status_code, 200)

response = self.csrf_client.get("/api/alias")
# 429 - too many requests for basic user
self.assertEqual(response.status_code, 429)

response = self.staff_csrf_client.get("/api/alias", format="json")
# 200 - staff user can access API unlimited times
self.assertEqual(response.status_code, 200)

def test_bulk_search_packages_endpoint_throttling(self):
data = json.dumps({"purls": ["pkg:foo/bar"]})

# A basic user can only access /packages/bulk_search 6 times a day
for i in range(0, 6):
response = self.csrf_client.post(
"/api/packages/bulk_search", data=data, content_type="application/json"
)
self.assertEqual(response.status_code, 200)
response = self.staff_csrf_client.post(
"/api/packages/bulk_search", data=data, content_type="application/json"
)
self.assertEqual(response.status_code, 200)

response = self.csrf_client.post(
"/api/packages/bulk_search", data=data, content_type="application/json"
)
# 429 - too many requests for basic user
self.assertEqual(response.status_code, 429)

response = self.staff_csrf_client.post(
"/api/packages/bulk_search", data=data, content_type="application/json"
)
# 200 - staff user can access API unlimited times
self.assertEqual(response.status_code, 200)

def test_bulk_search_cpes_endpoint_throttling(self):
data = json.dumps({"cpes": ["cpe:foo/bar"]})

# A basic user can only access /cpes/bulk_search 5 times a day
for i in range(0, 5):
response = self.csrf_client.post(
"/api/cpes/bulk_search", data=data, content_type="application/json"
)
self.assertEqual(response.status_code, 200)
response = self.staff_csrf_client.post(
"/api/cpes/bulk_search", data=data, content_type="application/json"
)
self.assertEqual(response.status_code, 200)

response = self.csrf_client.post(
"/api/cpes/bulk_search", data=data, content_type="application/json"
)
# 429 - too many requests for basic user
self.assertEqual(response.status_code, 429)

response = self.staff_csrf_client.post(
"/api/cpes/bulk_search", data=data, content_type="application/json"
)
# 200 - staff user can access API unlimited times
self.assertEqual(response.status_code, 200)
43 changes: 41 additions & 2 deletions vulnerabilities/throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#

from django.contrib.auth import get_user_model
from rest_framework.throttling import UserRateThrottle
from rest_framework.throttling import SimpleRateThrottle

User = get_user_model()


class StaffUserRateThrottle(UserRateThrottle):
class StaffUserRateThrottle(SimpleRateThrottle):
def allow_request(self, request, view):
"""
Do not apply throttling for superusers and admins.
Expand All @@ -22,3 +22,42 @@ def allow_request(self, request, view):
return True

return super().allow_request(request, view)

def get_cache_key(self, request, view):
"""
Return the cache key to use for this request.
"""
if request.user.is_authenticated:
ident = request.user.pk
else:
ident = self.get_ident(request)

return self.cache_format % {"scope": self.scope, "ident": ident}


class VulnerablePackagesAPIThrottle(StaffUserRateThrottle):
scope = "vulnerable_packages"


class BulkSearchPackagesAPIThrottle(StaffUserRateThrottle):
scope = "bulk_search_packages"


class PackagesAPIThrottle(StaffUserRateThrottle):
scope = "packages"


class VulnerabilitiesAPIThrottle(StaffUserRateThrottle):
scope = "vulnerabilities"


class AliasesAPIThrottle(StaffUserRateThrottle):
scope = "aliases"


class CPEAPIThrottle(StaffUserRateThrottle):
scope = "cpes"


class BulkSearchCPEAPIThrottle(StaffUserRateThrottle):
scope = "bulk_search_cpes"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no better way to do this? It looks like a lot of duplication here.

Copy link
Contributor Author

@TG1999 TG1999 Nov 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have now used ScopeRatedThrottling, please see if it looks good

26 changes: 23 additions & 3 deletions vulnerablecode/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,23 @@

LOGIN_REDIRECT_URL = "/"
LOGOUT_REDIRECT_URL = "/"
THROTTLING_RATE = env.str("THROTTLING_RATE", default="1000/day")
TEST_PACKAGE_THROTTLING_RATE = None
TEST_BULK_SEARCH_PACKAGE_THROTTLING_RATE = None
TEST_ALL_VULNERABLE_PACKAGE_THROTTLING_RATE = None
TEST_VULNERABILITIES_THROTTLING_RATE = None
TEST_CPES_THROTTLING_RATE = None
TEST_BULK_SEARCH_CPES_THROTTLING_RATE = None
TEST_ALIASES_THROTTLING_RATE = None
TG1999 marked this conversation as resolved.
Show resolved Hide resolved

if IS_TESTS:
VULNERABLECODEIO_REQUIRE_AUTHENTICATION = True
THROTTLING_RATE = "5/day"
TEST_PACKAGE_THROTTLING_RATE = "10/day"
TEST_BULK_SEARCH_PACKAGE_THROTTLING_RATE = "6/day"
TEST_ALL_VULNERABLE_PACKAGE_THROTTLING_RATE = "1/day"
TEST_VULNERABILITIES_THROTTLING_RATE = "8/day"
TEST_CPES_THROTTLING_RATE = "4/day"
TEST_BULK_SEARCH_CPES_THROTTLING_RATE = "5/day"
TEST_ALIASES_THROTTLING_RATE = "2/day"


USE_L10N = True
Expand Down Expand Up @@ -190,7 +202,15 @@
"DEFAULT_THROTTLE_CLASSES": [
"vulnerabilities.throttling.StaffUserRateThrottle",
],
"DEFAULT_THROTTLE_RATES": {"user": THROTTLING_RATE},
"DEFAULT_THROTTLE_RATES": {
"vulnerable_packages": TEST_ALL_VULNERABLE_PACKAGE_THROTTLING_RATE or "1/hour",
"bulk_search_packages": TEST_BULK_SEARCH_PACKAGE_THROTTLING_RATE or "5/hour",
"packages": TEST_PACKAGE_THROTTLING_RATE or "10/minute",
"vulnerabilities": TEST_VULNERABILITIES_THROTTLING_RATE or "10/minute",
"aliases": TEST_ALIASES_THROTTLING_RATE or "5/minute",
"cpes": TEST_CPES_THROTTLING_RATE or "5/minute",
"bulk_search_cpes": TEST_BULK_SEARCH_CPES_THROTTLING_RATE or "5/hour",
},
"DEFAULT_PAGINATION_CLASS": "vulnerabilities.pagination.SmallResultSetPagination",
# Limit the load on the Database returning a small number of records by default. https://github.com/nexB/vulnerablecode/issues/819
"PAGE_SIZE": 10,
Expand Down