From 1d25158508453710eb94e49af8aa297f1ec431b5 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Tue, 29 Mar 2022 13:56:42 -0700 Subject: [PATCH 1/3] Add to_json() and from_json() methods to Cohort class Signed-off-by: Gaurav Gupta --- raiwidgets/raiwidgets/cohort.py | 70 +++++++++++++++++++++++++++++++++ raiwidgets/tests/test_cohort.py | 33 +++++++++++----- 2 files changed, 93 insertions(+), 10 deletions(-) diff --git a/raiwidgets/raiwidgets/cohort.py b/raiwidgets/raiwidgets/cohort.py index 5e490cfa3a..bac20863db 100644 --- a/raiwidgets/raiwidgets/cohort.py +++ b/raiwidgets/raiwidgets/cohort.py @@ -3,6 +3,7 @@ """Module for defining cohorts in raiwidgets package.""" +import json from typing import Any, List, Optional import numpy as np @@ -397,6 +398,75 @@ def __init__(self, name: str): self.name = name self.cohort_filter_list = None + @staticmethod + def _cohort_serializer(obj): + """The function to serialize the Cohort class object. + + :param obj: Any member of the Cohort class object. + :type: Any + :return: Python dictionary. + :rtype: dict[Any, Any] + """ + return obj.__dict__ + + def to_json(self): + """Returns a serialized JSON string for the Cohort object. + + :return: The JSON string for the cohort. + :rtype: str + """ + return json.dumps(self, default=Cohort._cohort_serializer) + + @staticmethod + def _get_cohort_object(json_dict): + """Method to read a JSON dictionary and return a Cohort object. + + :param json_dict: JSON dictionary containing cohort data. + :type: dict[str, str] + :return: The Cohort object. + :rtype: Cohort + """ + if "name" not in json_dict: + raise UserConfigValidationException( + "No name field found for cohort deserialization") + if "cohort_filter_list" not in json_dict: + raise UserConfigValidationException( + "No cohort_filter_list field found for cohort deserialization") + if not isinstance(json_dict['cohort_filter_list'], list): + raise UserConfigValidationException( + "Field cohort_filter_list not of type list for " + "cohort deserialization") + + deserialized_cohort = Cohort(json_dict['name']) + for serialized_cohort_filter in json_dict['cohort_filter_list']: + if "method" not in serialized_cohort_filter: + raise UserConfigValidationException( + "Field method not found for cohort filter deserialization") + if "arg" not in serialized_cohort_filter: + raise UserConfigValidationException( + "Field arg not found for cohort filter deserialization") + if "column" not in serialized_cohort_filter: + raise UserConfigValidationException( + "Field column not found for cohort filter deserialization") + cohort_filter = CohortFilter( + method=serialized_cohort_filter['method'], + arg=serialized_cohort_filter['arg'], + column=serialized_cohort_filter['column']) + deserialized_cohort.add_cohort_filter(cohort_filter=cohort_filter) + return deserialized_cohort + + @staticmethod + def from_json(json_str): + """Method to read a json string and return a Cohort object. + + :param json_str: Serialized JSON string. + :type: str + :return: The Cohort object. + :rtype: Cohort + """ + json_dict = json.loads(json_str) + return Cohort._get_cohort_object(json_dict) + def add_cohort_filter(self, cohort_filter: CohortFilter): """Add a cohort filter into the cohort. :param cohort_filter: Cohort filter defined by CohortFilter class. diff --git a/raiwidgets/tests/test_cohort.py b/raiwidgets/tests/test_cohort.py index d47e2c8fea..6adc33fb61 100644 --- a/raiwidgets/tests/test_cohort.py +++ b/raiwidgets/tests/test_cohort.py @@ -497,40 +497,46 @@ def test_cohort_serialization_single_value_method(self, method): arg=[65], column='age') cohort_1 = Cohort(name="Cohort New") cohort_1.add_cohort_filter(cohort_filter_1) - json_str = json.dumps(cohort_1, - default=cohort_filter_json_converter) + json_str = cohort_1.to_json() assert 'Cohort New' in json_str assert method in json_str assert '[65]' in json_str assert 'age' in json_str - def test_cohort_serialization_in_range_method(self): + def test_cohort_serialization_deserialization_in_range_method(self): cohort_filter_1 = CohortFilter( method=CohortFilterMethods.METHOD_RANGE, arg=[65.0, 70.0], column='age') cohort_1 = Cohort(name="Cohort New") cohort_1.add_cohort_filter(cohort_filter_1) - json_str = json.dumps(cohort_1, - default=cohort_filter_json_converter) + json_str = cohort_1.to_json() assert 'Cohort New' in json_str assert CohortFilterMethods.METHOD_RANGE in json_str assert '65.0' in json_str assert '70.0' in json_str assert 'age' in json_str + cohort_1_new = Cohort.from_json(json_str) + assert cohort_1_new.name == cohort_1.name + assert len(cohort_1_new.cohort_filter_list) == \ + len(cohort_1.cohort_filter_list) + assert cohort_1_new.cohort_filter_list[0].method == \ + cohort_1.cohort_filter_list[0].method + @pytest.mark.parametrize('method', [CohortFilterMethods.METHOD_INCLUDES, CohortFilterMethods.METHOD_EXCLUDES]) - def test_cohort_serialization_include_exclude_methods(self, method): + def test_cohort_serialization_deserialization_include_exclude_methods( + self, method): cohort_filter_str = CohortFilter(method=method, arg=['val1', 'val2', 'val3'], column='age') cohort_str = Cohort(name="Cohort New Str") cohort_str.add_cohort_filter(cohort_filter_str) - json_str = json.dumps(cohort_str, - default=cohort_filter_json_converter) + + json_str = cohort_str.to_json() assert method in json_str assert 'val1' in json_str assert 'val2' in json_str @@ -542,14 +548,21 @@ def test_cohort_serialization_include_exclude_methods(self, method): column='age') cohort_int = Cohort(name="Cohort New Int") cohort_int.add_cohort_filter(cohort_filter_int) - json_str = json.dumps(cohort_filter_int, - default=cohort_filter_json_converter) + + json_str = cohort_int.to_json() assert method in json_str assert '1' in json_str assert '2' in json_str assert '3' in json_str assert 'age' in json_str + cohort_int_new = Cohort.from_json(json_str) + assert cohort_int_new.name == cohort_int.name + assert len(cohort_int_new.cohort_filter_list) == \ + len(cohort_int.cohort_filter_list) + assert cohort_int_new.cohort_filter_list[0].method == \ + cohort_int.cohort_filter_list[0].method + class TestCohortList: def test_cohort_list_serialization(self): From 180d6e5f604d400ed5b39b592c07ec68e52da9c8 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Fri, 1 Apr 2022 12:15:05 -0700 Subject: [PATCH 2/3] Address code review comments Signed-off-by: Gaurav Gupta --- raiwidgets/raiwidgets/cohort.py | 60 ++++++++++++++++++++++++--------- raiwidgets/tests/test_cohort.py | 49 ++++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 20 deletions(-) diff --git a/raiwidgets/raiwidgets/cohort.py b/raiwidgets/raiwidgets/cohort.py index bac20863db..f800f78b70 100644 --- a/raiwidgets/raiwidgets/cohort.py +++ b/raiwidgets/raiwidgets/cohort.py @@ -111,6 +111,11 @@ def __init__(self, method: str, arg: List[Any], column: str): self.arg = arg self.column = column + def __eq__(self, cohort_filter: Any): + return self.method == cohort_filter.method and \ + self.arg == cohort_filter.arg and \ + self.column == cohort_filter.column + def _validate_cohort_filter_parameters( self, method: str, arg: List[Any], column: str): """Validate the input values for the cohort filter. @@ -398,6 +403,32 @@ def __init__(self, name: str): self.name = name self.cohort_filter_list = None + def __eq__(self, cohort: Any): + same_name = self.name == cohort.name + if self.cohort_filter_list is None and \ + cohort.cohort_filter_list is None: + return same_name + elif self.cohort_filter_list is not None and \ + cohort.cohort_filter_list is None: + return False + elif self.cohort_filter_list is None and \ + cohort.cohort_filter_list is not None: + return False + + same_num_cohort_filters = len(self.cohort_filter_list) == \ + len(cohort.cohort_filter_list) + if not same_num_cohort_filters: + return False + + same_cohort_filters = True + for index in range(0, len(self.cohort_filter_list)): + if self.cohort_filter_list[index] != \ + cohort.cohort_filter_list[index]: + same_cohort_filters = False + break + + return same_name and same_cohort_filters + @staticmethod def _cohort_serializer(obj): """The function to serialize the Cohort class object. @@ -426,12 +457,13 @@ def _get_cohort_object(json_dict): :return: The Cohort object. :rtype: Cohort """ - if "name" not in json_dict: - raise UserConfigValidationException( - "No name field found for cohort deserialization") - if "cohort_filter_list" not in json_dict: - raise UserConfigValidationException( - "No cohort_filter_list field found for cohort deserialization") + cohort_fields = ["name", "cohort_filter_list"] + for cohort_field in cohort_fields: + if cohort_field not in json_dict: + raise UserConfigValidationException( + "No {0} field found for cohort deserialization".format( + cohort_field)) + if not isinstance(json_dict['cohort_filter_list'], list): raise UserConfigValidationException( "Field cohort_filter_list not of type list for " @@ -439,15 +471,13 @@ def _get_cohort_object(json_dict): deserialized_cohort = Cohort(json_dict['name']) for serialized_cohort_filter in json_dict['cohort_filter_list']: - if "method" not in serialized_cohort_filter: - raise UserConfigValidationException( - "Field method not found for cohort filter deserialization") - if "arg" not in serialized_cohort_filter: - raise UserConfigValidationException( - "Field arg not found for cohort filter deserialization") - if "column" not in serialized_cohort_filter: - raise UserConfigValidationException( - "Field column not found for cohort filter deserialization") + cohort_filter_fields = ["method", "arg", "column"] + for cohort_filter_field in cohort_filter_fields: + if cohort_filter_field not in serialized_cohort_filter: + raise UserConfigValidationException( + "No {0} field found for cohort filter " + "deserialization".format(cohort_filter_field)) + cohort_filter = CohortFilter( method=serialized_cohort_filter['method'], arg=serialized_cohort_filter['arg'], diff --git a/raiwidgets/tests/test_cohort.py b/raiwidgets/tests/test_cohort.py index 6adc33fb61..2116b86fd4 100644 --- a/raiwidgets/tests/test_cohort.py +++ b/raiwidgets/tests/test_cohort.py @@ -542,6 +542,8 @@ def test_cohort_serialization_deserialization_include_exclude_methods( assert 'val2' in json_str assert 'val3' in json_str assert 'age' in json_str + cohort_str_new = Cohort.from_json(json_str) + assert cohort_str == cohort_str_new cohort_filter_int = CohortFilter(method=method, arg=[1, 2, 3], @@ -557,11 +559,48 @@ def test_cohort_serialization_deserialization_include_exclude_methods( assert 'age' in json_str cohort_int_new = Cohort.from_json(json_str) - assert cohort_int_new.name == cohort_int.name - assert len(cohort_int_new.cohort_filter_list) == \ - len(cohort_int.cohort_filter_list) - assert cohort_int_new.cohort_filter_list[0].method == \ - cohort_int.cohort_filter_list[0].method + assert cohort_int == cohort_int_new + + def test_cohort_deserialization_error_conditions(self): + test_dict = {} + with pytest.raises( + UserConfigValidationException, + match="No name field found for cohort deserialization"): + Cohort.from_json(json.dumps(test_dict)) + + test_dict = {'name': 'Cohort New'} + with pytest.raises( + UserConfigValidationException, + match="No cohort_filter_list field found for cohort deserialization"): + Cohort.from_json(json.dumps(test_dict)) + + test_dict = {'name': 'Cohort New', 'cohort_filter_list': {}} + with pytest.raises(UserConfigValidationException, + match="Field cohort_filter_list not of type list " + "for cohort deserialization"): + Cohort.from_json(json.dumps(test_dict)) + + test_dict = {'name': 'Cohort New', 'cohort_filter_list': [{}]} + with pytest.raises( + UserConfigValidationException, + match="No method field found for cohort filter deserialization"): + Cohort.from_json(json.dumps(test_dict)) + + test_dict = { + 'name': 'Cohort New', + 'cohort_filter_list': [{"method": "fake_method"}]} + with pytest.raises( + UserConfigValidationException, + match="No arg field found for cohort filter deserialization"): + Cohort.from_json(json.dumps(test_dict)) + + test_dict = { + 'name': 'Cohort New', + 'cohort_filter_list': [{"method": "fake_method", "arg": []}]} + with pytest.raises( + UserConfigValidationException, + match="No column field found for cohort filter deserialization"): + Cohort.from_json(json.dumps(test_dict)) class TestCohortList: From 59138b3e4d59cdc9ec47dac7ff468801b8dcbb33 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Fri, 1 Apr 2022 13:47:35 -0700 Subject: [PATCH 3/3] Fix linting Signed-off-by: Gaurav Gupta --- raiwidgets/tests/test_cohort.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/raiwidgets/tests/test_cohort.py b/raiwidgets/tests/test_cohort.py index 2116b86fd4..4f36af5c64 100644 --- a/raiwidgets/tests/test_cohort.py +++ b/raiwidgets/tests/test_cohort.py @@ -571,7 +571,8 @@ def test_cohort_deserialization_error_conditions(self): test_dict = {'name': 'Cohort New'} with pytest.raises( UserConfigValidationException, - match="No cohort_filter_list field found for cohort deserialization"): + match="No cohort_filter_list field found for " + "cohort deserialization"): Cohort.from_json(json.dumps(test_dict)) test_dict = {'name': 'Cohort New', 'cohort_filter_list': {}} @@ -583,7 +584,8 @@ def test_cohort_deserialization_error_conditions(self): test_dict = {'name': 'Cohort New', 'cohort_filter_list': [{}]} with pytest.raises( UserConfigValidationException, - match="No method field found for cohort filter deserialization"): + match="No method field found for cohort filter " + "deserialization"): Cohort.from_json(json.dumps(test_dict)) test_dict = { @@ -599,7 +601,8 @@ def test_cohort_deserialization_error_conditions(self): 'cohort_filter_list': [{"method": "fake_method", "arg": []}]} with pytest.raises( UserConfigValidationException, - match="No column field found for cohort filter deserialization"): + match="No column field found for cohort filter " + "deserialization"): Cohort.from_json(json.dumps(test_dict))