diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 814febff3..414cae374 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,9 @@ Version v30.2.2 - We enabled API throttling for a basic user and for a staff user they can have unlimited access on API. +- We added throttle rate for each API endpoint and it can be + configured from the settings #991 https://github.com/nexB/vulnerablecode/issues/991. + Version v30.2.1 ---------------- diff --git a/vulnerabilities/api.py b/vulnerabilities/api.py index 5c949cda4..bb6398399 100644 --- a/vulnerabilities/api.py +++ b/vulnerabilities/api.py @@ -23,6 +23,7 @@ from vulnerabilities.models import VulnerabilityReference from vulnerabilities.models import VulnerabilitySeverity from vulnerabilities.models import get_purl_query_lookups +from vulnerabilities.throttling import StaffUserRateThrottle class VulnerabilitySeveritySerializer(serializers.ModelSerializer): @@ -220,9 +221,11 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = PackageSerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageFilterSet + throttle_classes = [StaffUserRateThrottle] + throttle_scope = "packages" # TODO: Fix the swagger documentation for this endpoint - @action(detail=False, methods=["post"]) + @action(detail=False, methods=["post"], throttle_scope="bulk_search_packages") def bulk_search(self, request): """ See https://github.com/nexB/vulnerablecode/pull/369#issuecomment-796877606 for docs @@ -246,7 +249,7 @@ def bulk_search(self, request): if purl_data: purl_response = PackageSerializer(purl_data[0], context={"request": request}).data else: - purl_response = purl + purl_response = purl.to_dict() purl_response["unresolved_vulnerabilities"] = [] purl_response["resolved_vulnerabilities"] = [] purl_response["purl"] = purl_string @@ -254,7 +257,7 @@ def bulk_search(self, request): return Response(response) - @action(detail=False, methods=["get"]) + @action(detail=False, methods=["get"], throttle_scope="vulnerable_packages") def all(self, request): """ Return all the vulnerable Package URLs. @@ -302,6 +305,8 @@ def get_queryset(self): serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = VulnerabilityFilterSet + throttle_classes = [StaffUserRateThrottle] + throttle_scope = "vulnerabilities" class CPEFilterSet(filters.FilterSet): @@ -318,9 +323,11 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet): ).distinct() serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) + throttle_classes = [StaffUserRateThrottle] filterset_class = CPEFilterSet + throttle_scope = "cpes" - @action(detail=False, methods=["post"]) + @action(detail=False, methods=["post"], throttle_scope="bulk_search_cpes") def bulk_search(self, request): """ This endpoint is used to search for vulnerabilities by more than one CPE. @@ -357,3 +364,5 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = AliasFilterSet + throttle_classes = [StaffUserRateThrottle] + throttle_scope = "aliases" diff --git a/vulnerabilities/tests/test_throttling.py b/vulnerabilities/tests/test_throttling.py index ade4726ef..fe25137a0 100644 --- a/vulnerabilities/tests/test_throttling.py +++ b/vulnerabilities/tests/test_throttling.py @@ -7,6 +7,8 @@ # See https://aboutcode.org for more information about nexB OSS projects. # +import json + from django.contrib.auth import get_user_model from rest_framework.test import APIClient from rest_framework.test import APITestCase @@ -30,10 +32,10 @@ def setUp(self): self.staff_csrf_client = APIClient(enforce_csrf_checks=True) self.staff_csrf_client.credentials(HTTP_AUTHORIZATION=self.staff_auth) - def test_api_throttling(self): + def test_packages_endpoint_throttling(self): - # A basic user can only access API 5 times a day - for i in range(0, 5): + # A basic user can only access /packages endpoint 10 times a day + for i in range(0, 10): response = self.csrf_client.get("/api/packages") self.assertEqual(response.status_code, 200) response = self.staff_csrf_client.get("/api/packages") @@ -46,3 +48,123 @@ def test_api_throttling(self): response = self.staff_csrf_client.get("/api/packages", format="json") # 200 - staff user can access API unlimited times self.assertEqual(response.status_code, 200) + + def test_cpes_endpoint_throttling(self): + + # A basic user can only access /cpes endpoint 4 times a day + for i in range(0, 4): + response = self.csrf_client.get("/api/cpes") + self.assertEqual(response.status_code, 200) + response = self.staff_csrf_client.get("/api/cpes") + self.assertEqual(response.status_code, 200) + + response = self.csrf_client.get("/api/cpes") + # 429 - too many requests for basic user + self.assertEqual(response.status_code, 429) + + response = self.staff_csrf_client.get("/api/cpes", format="json") + # 200 - staff user can access API unlimited times + self.assertEqual(response.status_code, 200) + + def test_all_vulnerable_packages_endpoint_throttling(self): + + # A basic user can only access /packages/all 1 time a day + for i in range(0, 1): + response = self.csrf_client.get("/api/packages/all") + self.assertEqual(response.status_code, 200) + response = self.staff_csrf_client.get("/api/packages/all") + self.assertEqual(response.status_code, 200) + + response = self.csrf_client.get("/api/packages/all") + # 429 - too many requests for basic user + self.assertEqual(response.status_code, 429) + + response = self.staff_csrf_client.get("/api/packages/all", format="json") + # 200 - staff user can access API unlimited times + self.assertEqual(response.status_code, 200) + + def test_vulnerabilities_endpoint_throttling(self): + + # A basic user can only access /vulnerabilities 8 times a day + for i in range(0, 8): + response = self.csrf_client.get("/api/vulnerabilities") + self.assertEqual(response.status_code, 200) + response = self.staff_csrf_client.get("/api/vulnerabilities") + self.assertEqual(response.status_code, 200) + + response = self.csrf_client.get("/api/vulnerabilities") + # 429 - too many requests for basic user + self.assertEqual(response.status_code, 429) + + response = self.staff_csrf_client.get("/api/vulnerabilities", format="json") + # 200 - staff user can access API unlimited times + self.assertEqual(response.status_code, 200) + + def test_aliases_endpoint_throttling(self): + + # A basic user can only access /alias 2 times a day + for i in range(0, 2): + response = self.csrf_client.get("/api/alias") + self.assertEqual(response.status_code, 200) + response = self.staff_csrf_client.get("/api/alias") + self.assertEqual(response.status_code, 200) + + response = self.csrf_client.get("/api/alias") + # 429 - too many requests for basic user + self.assertEqual(response.status_code, 429) + + response = self.staff_csrf_client.get("/api/alias", format="json") + # 200 - staff user can access API unlimited times + self.assertEqual(response.status_code, 200) + + def test_bulk_search_packages_endpoint_throttling(self): + data = json.dumps({"purls": ["pkg:foo/bar"]}) + + # A basic user can only access /packages/bulk_search 6 times a day + for i in range(0, 6): + response = self.csrf_client.post( + "/api/packages/bulk_search", data=data, content_type="application/json" + ) + self.assertEqual(response.status_code, 200) + response = self.staff_csrf_client.post( + "/api/packages/bulk_search", data=data, content_type="application/json" + ) + self.assertEqual(response.status_code, 200) + + response = self.csrf_client.post( + "/api/packages/bulk_search", data=data, content_type="application/json" + ) + # 429 - too many requests for basic user + self.assertEqual(response.status_code, 429) + + response = self.staff_csrf_client.post( + "/api/packages/bulk_search", data=data, content_type="application/json" + ) + # 200 - staff user can access API unlimited times + self.assertEqual(response.status_code, 200) + + def test_bulk_search_cpes_endpoint_throttling(self): + data = json.dumps({"cpes": ["cpe:foo/bar"]}) + + # A basic user can only access /cpes/bulk_search 5 times a day + for i in range(0, 5): + response = self.csrf_client.post( + "/api/cpes/bulk_search", data=data, content_type="application/json" + ) + self.assertEqual(response.status_code, 200) + response = self.staff_csrf_client.post( + "/api/cpes/bulk_search", data=data, content_type="application/json" + ) + self.assertEqual(response.status_code, 200) + + response = self.csrf_client.post( + "/api/cpes/bulk_search", data=data, content_type="application/json" + ) + # 429 - too many requests for basic user + self.assertEqual(response.status_code, 429) + + response = self.staff_csrf_client.post( + "/api/cpes/bulk_search", data=data, content_type="application/json" + ) + # 200 - staff user can access API unlimited times + self.assertEqual(response.status_code, 200) diff --git a/vulnerabilities/throttling.py b/vulnerabilities/throttling.py index e98db3806..12fb23426 100644 --- a/vulnerabilities/throttling.py +++ b/vulnerabilities/throttling.py @@ -6,14 +6,10 @@ # See https://github.com/nexB/vulnerablecode for support or download. # See https://aboutcode.org for more information about nexB OSS projects. # +from rest_framework.throttling import ScopedRateThrottle -from django.contrib.auth import get_user_model -from rest_framework.throttling import UserRateThrottle -User = get_user_model() - - -class StaffUserRateThrottle(UserRateThrottle): +class StaffUserRateThrottle(ScopedRateThrottle): def allow_request(self, request, view): """ Do not apply throttling for superusers and admins. diff --git a/vulnerablecode/settings.py b/vulnerablecode/settings.py index 99b52a23a..60d74d707 100644 --- a/vulnerablecode/settings.py +++ b/vulnerablecode/settings.py @@ -150,11 +150,28 @@ LOGIN_REDIRECT_URL = "/" LOGOUT_REDIRECT_URL = "/" -THROTTLING_RATE = env.str("THROTTLING_RATE", default="1000/day") + +REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { + "vulnerable_packages": "1/hour", + "bulk_search_packages": "5/hour", + "packages": "10/minute", + "vulnerabilities": "10/minute", + "aliases": "5/minute", + "cpes": "5/minute", + "bulk_search_cpes": "5/hour", +} if IS_TESTS: VULNERABLECODEIO_REQUIRE_AUTHENTICATION = True - THROTTLING_RATE = "5/day" + REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { + "vulnerable_packages": "1/day", + "bulk_search_packages": "6/day", + "packages": "10/day", + "vulnerabilities": "8/day", + "aliases": "2/day", + "cpes": "4/day", + "bulk_search_cpes": "5/day", + } USE_L10N = True @@ -190,7 +207,7 @@ "DEFAULT_THROTTLE_CLASSES": [ "vulnerabilities.throttling.StaffUserRateThrottle", ], - "DEFAULT_THROTTLE_RATES": {"user": THROTTLING_RATE}, + "DEFAULT_THROTTLE_RATES": REST_FRAMEWORK_DEFAULT_THROTTLE_RATES, "DEFAULT_PAGINATION_CLASS": "vulnerabilities.pagination.SmallResultSetPagination", # Limit the load on the Database returning a small number of records by default. https://github.com/nexB/vulnerablecode/issues/819 "PAGE_SIZE": 10,