diff --git a/onadata/libs/tests/utils/test_csv_import.py b/onadata/libs/tests/utils/test_csv_import.py index 0ca4495da4..70f12a6615 100644 --- a/onadata/libs/tests/utils/test_csv_import.py +++ b/onadata/libs/tests/utils/test_csv_import.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import json import os import re from builtins import open @@ -21,6 +22,7 @@ from onadata.libs.utils.common_tags import IMPORTED_VIA_CSV_BY from onadata.libs.utils.csv_import import get_submission_meta_dict from onadata.libs.utils.user_auth import get_user_default_project +from onadata.libs.utils.csv_import import get_columns_by_type def strip_xml_uuid(s): @@ -523,3 +525,31 @@ def test_csv_import_with_overwrite(self): self.assertEqual(count, 1) self.assertEqual(self.xform.num_of_submissions, 1) + + def test_get_columns_by_type(self): + """ + Test get_columns_by_type() returns columns in groups + """ + self.xls_file_path = os.path.join( + self.fixtures_dir, "form_with_multiple_select.xlsx" + ) + self._publish_xls_file(self.xls_file_path) + xform = XForm.objects.get() + columns = get_columns_by_type(["date"], json.loads(xform.json)) + self.assertEqual( + columns, ["section_A/date_of_survey", "section_B/year_established"] + ) + good_csv = open( + os.path.join( + self.fixtures_dir, "csv_import_with_multiple_select.csv" + ), "rb" + ) + csv_import.submit_csv(self.user.username, xform, good_csv) + self.assertEqual(Instance.objects.count(), 1) + submission = Instance.objects.first() + self.assertEqual(submission.status, "imported_via_csv") + self.assertEqual(submission.json["section_A/date_of_survey"], + "2015-09-10") + self.assertTrue( + submission.json["section_B/year_established"].startswith("1890") + ) diff --git a/onadata/libs/utils/csv_import.py b/onadata/libs/utils/csv_import.py index 211c83d8d0..f25a7cec52 100644 --- a/onadata/libs/utils/csv_import.py +++ b/onadata/libs/utils/csv_import.py @@ -611,10 +611,26 @@ def get_columns_by_type(type_list, form_json): within the type_list :rtype: list """ - return [ - dt.get('name') for dt in form_json.get('children') - if dt.get('type') in type_list - ] + + def _column_by_type(item_list, prefix=""): + found = [] + for item in item_list: + if item["type"] in ["group", "repeat"]: + prefix = "/".join( + [prefix, item["name"]] + ) if prefix else item["name"] + found.extend(_column_by_type(item["children"], prefix)) + prefix = "" # Reset prefix to blank + else: + if item["type"] in type_list: + name = "%s/%s" % ( + prefix, item["name"] + ) if prefix else item["name"] + found.append(name) + + return found + + return _column_by_type(form_json["children"]) def validate_row(row, columns):