diff --git a/onadata/apps/api/tests/viewsets/test_data_viewset.py b/onadata/apps/api/tests/viewsets/test_data_viewset.py index 293623c47b..9ce726b464 100644 --- a/onadata/apps/api/tests/viewsets/test_data_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_data_viewset.py @@ -702,17 +702,15 @@ def test_paginate_and_sort_streaming_data(self): self.assertEqual(expected_order[3:], pg_2_items_in_order) @override_settings(STREAM_DATA=True) - def test_get_sorted_paginated_fields_in_streaming_data(self): + def test_get_paginated_fields_in_streaming_data(self): """ - Test that "sort" query param works as expected for paginated - responses + Test that we can return data for certain fields """ self._make_submissions() view = DataViewSet.as_view({'get': 'list'}) formid = self.xform.pk - # will result in a queryset due to the page and page_size params - # hence paging and thus len(self.object_list) for length + # will result in a generator due to the JSON sort query_data = { "page_size": 3, "page": 1, @@ -726,13 +724,11 @@ def test_get_sorted_paginated_fields_in_streaming_data(self): streaming_data = json.loads(''.join( [c.decode('utf-8') for c in response.streaming_content])) - self.assertEqual(len(streaming_data), 3) - # Test `date_created` field is sorted correctly - expected_order = [1, 3, 4] - items_in_order = [sub.get('_id') for sub in streaming_data] - - self.assertEqual(expected_order[:3], items_in_order) - self.assertTrue(response.has_header('ETag')) + self.assertEqual(len(streaming_data), 4) + # Test data returned is only the required fields + self.assertTrue(any('_id' in x for x in streaming_data)) + self.assertTrue(any( + '_submission_time' in x for x in streaming_data)) def test_data_start_limit_sort(self): self._make_submissions() @@ -846,47 +842,45 @@ def test_data_public(self): view = DataViewSet.as_view({'get': 'list'}) request = self.factory.get('/', **self.extra) response = view(request, pk='public') - self.assertEqual(response.status_code, 400) - self.assertEqual(response.get('Cache-Control'), None) - error_message = "Invalid form ID. It must be a positive integer" - self.assertEqual(str(response.data['detail']), error_message) - + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, []) self.xform.shared_data = True self.xform.save() + formid = self.xform.pk + data = _data_list(formid) response = view(request, pk='public') - self.assertEqual(response.status_code, 400) - self.assertEqual(response.get('Cache-Control'), None) - self.assertEqual(str(response.data['detail']), error_message) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, data) def test_data_public_anon_user(self): self._make_submissions() view = DataViewSet.as_view({'get': 'list'}) request = self.factory.get('/') response = view(request, pk='public') - self.assertEqual(response.status_code, 404) - error_message = "Not found." - self.assertEqual(str(response.data['detail']), error_message) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, []) self.xform.shared_data = True self.xform.save() + formid = self.xform.pk + data = _data_list(formid) response = view(request, pk='public') - self.assertEqual(response.status_code, 404) - self.assertEqual(str(response.data['detail']), error_message) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, data) def test_data_user_public(self): self._make_submissions() view = DataViewSet.as_view({'get': 'list'}) request = self.factory.get('/', **self.extra) response = view(request, pk='public') - self.assertEqual(response.status_code, 400) - self.assertEqual(response.get('Cache-Control'), None) - error_message = "Invalid form ID. It must be a positive integer" - self.assertEqual(str(response.data['detail']), error_message) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, []) self.xform.shared_data = True self.xform.save() + formid = self.xform.pk + data = _data_list(formid) response = view(request, pk='public') - self.assertEqual(response.status_code, 400) - self.assertEqual(response.get('Cache-Control'), None) - self.assertEqual(str(response.data['detail']), error_message) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, data) def test_data_bad_formid(self): self._make_submissions() diff --git a/onadata/apps/api/viewsets/data_viewset.py b/onadata/apps/api/viewsets/data_viewset.py index feb4e38cc4..f106524890 100644 --- a/onadata/apps/api/viewsets/data_viewset.py +++ b/onadata/apps/api/viewsets/data_viewset.py @@ -647,11 +647,10 @@ def _get_data(self, query, fields, sort, start, limit, is_public_request): if should_paginate: self.paginator.page_size = retrieval_threshold - if isinstance(self.object_list, types.GeneratorType) and \ - should_paginate: + if isinstance(self.object_list, types.GeneratorType): # Unpack generator object to list self.object_list = list(self.object_list) - else: + if should_paginate: current_page = query_param_keys.get( self.paginator.page_query_param, 1) current_page_size = query_param_keys.get( @@ -664,7 +663,6 @@ def _get_data(self, query, fields, sort, start, limit, is_public_request): current_page_size=current_page_size ) - if should_paginate: try: self.object_list = self.paginate_queryset(self.object_list) except OperationalError: