diff --git a/docs/media.rst b/docs/media.rst index d08844df32..02b48be133 100644 --- a/docs/media.rst +++ b/docs/media.rst @@ -185,4 +185,24 @@ Response http://api.ona.io/api/v1/media/1.jpg - +Retrieve attachment count for a form +------------------------------------ +Returns the total number of attachments for a form + +.. raw:: html + +
GET /api/v1/media/count?xform={xform_id}
+ +Example +^^^^^^^ +:: + + + curl -X GET https://api.ona.io/api/v1/media/count?xform=1 + +Response +^^^^^^^^ +:: + + + {"count": 1} diff --git a/onadata/apps/api/tests/viewsets/test_attachment_viewset.py b/onadata/apps/api/tests/viewsets/test_attachment_viewset.py index b4bf322be9..40a0f18097 100644 --- a/onadata/apps/api/tests/viewsets/test_attachment_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_attachment_viewset.py @@ -28,6 +28,9 @@ def setUp(self): self.list_view = AttachmentViewSet.as_view({ 'get': 'list' }) + self.count_view = AttachmentViewSet.as_view({ + 'get': 'count' + }) self._publish_xls_form_to_project() @@ -371,3 +374,11 @@ def test_direct_image_link_uppercase(self): self.assertEqual(response.status_code, 200) self.assertTrue(isinstance(response.data, basestring)) self.assertEqual(response.data, attachment_url(self.attachment)) + + def test_total_count(self): + self._submit_transport_instance_w_attachment() + xform_id = self.attachment.instance.xform.id + request = self.factory.get( + '/count', data={"xform": xform_id}, **self.extra) + response = self.count_view(request) + self.assertEqual(response.data['count'], 1) diff --git a/onadata/apps/api/viewsets/attachment_viewset.py b/onadata/apps/api/viewsets/attachment_viewset.py index 40a45a3d27..0d63153ca6 100644 --- a/onadata/apps/api/viewsets/attachment_viewset.py +++ b/onadata/apps/api/viewsets/attachment_viewset.py @@ -6,6 +6,7 @@ from django.conf import settings from rest_framework import renderers from rest_framework import viewsets +from rest_framework.decorators import action from rest_framework.exceptions import ParseError from rest_framework.response import Response @@ -83,6 +84,14 @@ def retrieve(self, request, *args, **kwargs): return Response(serializer.data) + @action(methods=['GET'], detail=False) + def count(self, request, *args, **kwargs): + data = { + "count": self.filter_queryset(self.get_queryset()).count() + } + + return Response(data=data) + def list(self, request, *args, **kwargs): if request.user.is_anonymous: xform = request.query_params.get('xform')