diff --git a/wagtail_ab_testing/test/tests/test_views.py b/wagtail_ab_testing/test/tests/test_views.py index 5ab697d..bfd19ca 100644 --- a/wagtail_ab_testing/test/tests/test_views.py +++ b/wagtail_ab_testing/test/tests/test_views.py @@ -1,8 +1,9 @@ import datetime +from django.contrib.auth import get_user_model from django.urls import reverse from freezegun import freeze_time -from rest_framework.test import APITestCase +from rest_framework.test import APIClient, APITestCase from wagtail.core.models import Page from wagtail_ab_testing.models import AbTest @@ -73,6 +74,26 @@ def test_register_participant_finish(self): self.ab_test.refresh_from_db() self.assertEqual(self.ab_test.status, AbTest.STATUS_FINISHED) + def test_register_participant_authenticated_user(self): + # By default, Django REST framework will enforce CSRF checks on authenticated users + # We disable these by removing all authentication/permission classes from the view + client = APIClient(enforce_csrf_checks=True) + + User = get_user_model() + User.objects.create_user('foo', 'myemail@test.com', 'bar') + client.login(username='foo', password='bar') + + response = client.post( + reverse('wagtail_ab_testing:register_participant'), + { + 'test_id': self.ab_test.id, + 'version': 'control', + } + ) + + # Shouldn't give 403 error + self.assertEqual(response.status_code, 200) + @freeze_time('2020-11-04T22:37:00Z') class TestGoalReached(APITestCase): @@ -146,3 +167,23 @@ def test_log_conversion_for_something_else(self): # This shouldn't create a history log self.assertFalse(self.ab_test.hourly_logs.exists()) + + def test_log_conversion_authenticated_user(self): + # By default, Django REST framework will enforce CSRF checks on authenticated users + # We disable these by removing all authentication/permission classes from the view + client = APIClient(enforce_csrf_checks=True) + + User = get_user_model() + User.objects.create_user('foo', 'myemail@test.com', 'bar') + client.login(username='foo', password='bar') + + response = client.post( + reverse('wagtail_ab_testing:goal_reached', args=[]), + { + 'test_id': self.ab_test.id, + 'version': 'control' + } + ) + + # Shouldn't give 403 error + self.assertEqual(response.status_code, 200) diff --git a/wagtail_ab_testing/views.py b/wagtail_ab_testing/views.py index a389ae2..21bf3de 100644 --- a/wagtail_ab_testing/views.py +++ b/wagtail_ab_testing/views.py @@ -14,7 +14,7 @@ import django_filters from django_filters.constants import EMPTY_VALUES from rest_framework import status -from rest_framework.decorators import api_view +from rest_framework.decorators import api_view, authentication_classes, permission_classes from rest_framework.response import Response from wagtail.admin import messages from wagtail.admin.action_menu import ActionMenuItem @@ -435,6 +435,8 @@ def get_queryset(self): @csrf_exempt @api_view(['POST']) +@authentication_classes([]) +@permission_classes([]) def register_participant(request): test_id = request.data.get('test_id', None) if test_id is None: @@ -467,6 +469,8 @@ def register_participant(request): @csrf_exempt @api_view(['POST']) +@authentication_classes([]) +@permission_classes([]) def goal_reached(request): test_id = request.data.get('test_id', None) if test_id is None: