diff --git a/onadata/apps/api/tests/viewsets/test_data_viewset.py b/onadata/apps/api/tests/viewsets/test_data_viewset.py index 55be12b860..95ec149d22 100644 --- a/onadata/apps/api/tests/viewsets/test_data_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_data_viewset.py @@ -2382,6 +2382,44 @@ def test_submission_count_for_day_tracking(self): cache.clear() inst_three.set_deleted() + def test_data_query_multiple_condition(self): + """ + Test that all conditions are met when a user queries for + data + """ + self._make_submissions() + view = DataViewSet.as_view({'get': 'list'}) + query_str = ( + '{"$or":[{"transport/loop_over_transport_types_frequency' + '/ambulance/frequency_to_referral_facility":"weekly","t' + 'ransport/loop_over_transport_types_frequency/ambulanc' + 'e/frequency_to_referral_facility":"daily"}]}' + ) + request = self.factory.get(f'/?query={query_str}', **self.extra) + response = view(request, pk=self.xform.pk) + count = 0 + + for inst in self.xform.instances.all(): + if inst.json.get( + 'transport/loop_over_transport_types_frequency' + '/ambulance/frequency_to_referral_facility' + ) in ['daily', 'weekly']: + count += 1 + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), count) + + query_str = ( + '{"$or":[{"transport/loop_over_transport_types_frequency' + '/ambulance/frequency_to_referral_facility":"weekly"}, {"t' + 'ransport/loop_over_transport_types_frequency/ambulanc' + 'e/frequency_to_referral_facility":"daily"}]}' + ) + request = self.factory.get(f'/?query={query_str}', **self.extra) + response = view(request, pk=self.xform.pk) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), count) + def test_data_query_ornull(self): """ Test that a user is able to query for null with the diff --git a/onadata/apps/viewer/parsed_instance_tools.py b/onadata/apps/viewer/parsed_instance_tools.py index bc2d3a534c..f7b50a2c46 100644 --- a/onadata/apps/viewer/parsed_instance_tools.py +++ b/onadata/apps/viewer/parsed_instance_tools.py @@ -3,6 +3,7 @@ import datetime from builtins import str as text from future.utils import iteritems +from typing import Any, Tuple from onadata.libs.utils.common_tags import MONGO_STRFTIME, DATE_FORMAT @@ -82,6 +83,21 @@ def _parse_where(query, known_integers, known_decimals, or_where, or_params): return where + or_where, where_params + or_params +def _merge_duplicate_keys(pairs: Tuple[str, Any]): + ret = {} + + for field, value in pairs: + if not ret.get(field): + ret[field] = [] + ret[field].append(value) + + for key, value in ret.items(): + if len(value) == 1: + ret[key] = value[0] + + return ret + + def get_where_clause(query, form_integer_fields=None, form_decimal_fields=None): if form_integer_fields is None: @@ -95,7 +111,8 @@ def get_where_clause(query, form_integer_fields=None, try: if query and isinstance(query, (dict, six.string_types)): - query = query if isinstance(query, dict) else json.loads(query) + query = query if isinstance(query, dict) else json.loads( + query, object_pairs_hook=_merge_duplicate_keys) or_where = [] or_params = [] if isinstance(query, list): @@ -108,6 +125,10 @@ def get_where_clause(query, form_integer_fields=None, for k, v in or_query.items(): if v is None: or_where.extend([u"json->>'{}' IS NULL".format(k)]) + elif isinstance(v, list): + for value in v: + or_where.extend(["json->>%s = %s"]) + or_params.extend([k, value]) else: or_where.extend( [u"json->>%s = %s"])