Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_json() and from_json() methods to Cohort class #1300

Merged
merged 6 commits into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions raiwidgets/raiwidgets/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Module for defining cohorts in raiwidgets package."""

import json
from typing import Any, List, Optional

import numpy as np
Expand Down Expand Up @@ -397,6 +398,75 @@ def __init__(self, name: str):
self.name = name
self.cohort_filter_list = None

@staticmethod
def _cohort_serializer(obj):
"""The function to serialize the Cohort class object.

:param obj: Any member of the Cohort class object.
:type: Any
:return: Python dictionary.
:rtype: dict[Any, Any]
"""
return obj.__dict__

def to_json(self):
"""Returns a serialized JSON string for the Cohort object.

:return: The JSON string for the cohort.
:rtype: str
"""
return json.dumps(self, default=Cohort._cohort_serializer)
gaugup marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _get_cohort_object(json_dict):
"""Method to read a JSON dictionary and return a Cohort object.

:param json_dict: JSON dictionary containing cohort data.
:type: dict[str, str]
:return: The Cohort object.
:rtype: Cohort
"""
if "name" not in json_dict:
raise UserConfigValidationException(
"No name field found for cohort deserialization")
if "cohort_filter_list" not in json_dict:
gaugup marked this conversation as resolved.
Show resolved Hide resolved
raise UserConfigValidationException(
"No cohort_filter_list field found for cohort deserialization")
if not isinstance(json_dict['cohort_filter_list'], list):
raise UserConfigValidationException(
"Field cohort_filter_list not of type list for "
"cohort deserialization")

deserialized_cohort = Cohort(json_dict['name'])
for serialized_cohort_filter in json_dict['cohort_filter_list']:
if "method" not in serialized_cohort_filter:
gaugup marked this conversation as resolved.
Show resolved Hide resolved
raise UserConfigValidationException(
"Field method not found for cohort filter deserialization")
if "arg" not in serialized_cohort_filter:
raise UserConfigValidationException(
"Field arg not found for cohort filter deserialization")
if "column" not in serialized_cohort_filter:
raise UserConfigValidationException(
"Field column not found for cohort filter deserialization")
cohort_filter = CohortFilter(
method=serialized_cohort_filter['method'],
arg=serialized_cohort_filter['arg'],
column=serialized_cohort_filter['column'])
deserialized_cohort.add_cohort_filter(cohort_filter=cohort_filter)
return deserialized_cohort

@staticmethod
def from_json(json_str):
"""Method to read a json string and return a Cohort object.

:param json_str: Serialized JSON string.
:type: str
:return: The Cohort object.
:rtype: Cohort
"""
json_dict = json.loads(json_str)
return Cohort._get_cohort_object(json_dict)

def add_cohort_filter(self, cohort_filter: CohortFilter):
"""Add a cohort filter into the cohort.
:param cohort_filter: Cohort filter defined by CohortFilter class.
Expand Down
33 changes: 23 additions & 10 deletions raiwidgets/tests/test_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,40 +497,46 @@ def test_cohort_serialization_single_value_method(self, method):
arg=[65], column='age')
cohort_1 = Cohort(name="Cohort New")
cohort_1.add_cohort_filter(cohort_filter_1)
json_str = json.dumps(cohort_1,
default=cohort_filter_json_converter)
json_str = cohort_1.to_json()

assert 'Cohort New' in json_str
assert method in json_str
assert '[65]' in json_str
assert 'age' in json_str

def test_cohort_serialization_in_range_method(self):
def test_cohort_serialization_deserialization_in_range_method(self):
gaugup marked this conversation as resolved.
Show resolved Hide resolved
cohort_filter_1 = CohortFilter(
method=CohortFilterMethods.METHOD_RANGE,
arg=[65.0, 70.0], column='age')
cohort_1 = Cohort(name="Cohort New")
cohort_1.add_cohort_filter(cohort_filter_1)
json_str = json.dumps(cohort_1,
default=cohort_filter_json_converter)

json_str = cohort_1.to_json()
assert 'Cohort New' in json_str
assert CohortFilterMethods.METHOD_RANGE in json_str
assert '65.0' in json_str
assert '70.0' in json_str
assert 'age' in json_str

cohort_1_new = Cohort.from_json(json_str)
assert cohort_1_new.name == cohort_1.name
assert len(cohort_1_new.cohort_filter_list) == \
len(cohort_1.cohort_filter_list)
assert cohort_1_new.cohort_filter_list[0].method == \
cohort_1.cohort_filter_list[0].method
gaugup marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize('method',
[CohortFilterMethods.METHOD_INCLUDES,
CohortFilterMethods.METHOD_EXCLUDES])
def test_cohort_serialization_include_exclude_methods(self, method):
def test_cohort_serialization_deserialization_include_exclude_methods(
self, method):
cohort_filter_str = CohortFilter(method=method,
arg=['val1', 'val2', 'val3'],
column='age')
cohort_str = Cohort(name="Cohort New Str")
cohort_str.add_cohort_filter(cohort_filter_str)
json_str = json.dumps(cohort_str,
default=cohort_filter_json_converter)

json_str = cohort_str.to_json()
assert method in json_str
assert 'val1' in json_str
assert 'val2' in json_str
Expand All @@ -542,14 +548,21 @@ def test_cohort_serialization_include_exclude_methods(self, method):
column='age')
cohort_int = Cohort(name="Cohort New Int")
cohort_int.add_cohort_filter(cohort_filter_int)
json_str = json.dumps(cohort_filter_int,
default=cohort_filter_json_converter)

json_str = cohort_int.to_json()
assert method in json_str
assert '1' in json_str
assert '2' in json_str
assert '3' in json_str
assert 'age' in json_str

cohort_int_new = Cohort.from_json(json_str)
assert cohort_int_new.name == cohort_int.name
gaugup marked this conversation as resolved.
Show resolved Hide resolved
assert len(cohort_int_new.cohort_filter_list) == \
len(cohort_int.cohort_filter_list)
assert cohort_int_new.cohort_filter_list[0].method == \
cohort_int.cohort_filter_list[0].method


class TestCohortList:
def test_cohort_list_serialization(self):
Expand Down