diff --git a/raiwidgets/raiwidgets/_cohort.py b/raiwidgets/raiwidgets/_cohort.py index e0dd8dba89..5e490cfa3a 100644 --- a/raiwidgets/raiwidgets/_cohort.py +++ b/raiwidgets/raiwidgets/_cohort.py @@ -3,7 +3,10 @@ """Module for defining cohorts in raiwidgets package.""" -from typing import Any, List +from typing import Any, List, Optional + +import numpy as np +import pandas as pd from responsibleai.exceptions import UserConfigValidationException @@ -39,9 +42,20 @@ class CohortFilterMethods: METHOD_EQUAL] +class ClassificationOutcomes: + """Defines the possible values for classification outcomes. + """ + FALSE_NEGATIVE = 'False negative' + FALSE_POSITIVE = 'False positive' + TRUE_NEGATIVE = 'True negative' + TRUE_POSITIVE = 'True positive' + + ALL = [FALSE_NEGATIVE, FALSE_POSITIVE, + TRUE_NEGATIVE, TRUE_POSITIVE] + + def cohort_filter_json_converter(obj): """Helper function to convert CohortFilter object to json. - :param obj: Object to convert to json. :type obj: object :return: The converted json. @@ -57,7 +71,6 @@ def cohort_filter_json_converter(obj): class CohortFilter: """Defines the cohort filter. - :param method: Cohort filter method from one of CohortFilterMethods. :type method: str :param arg: List of values to be used by the cohort filter. @@ -66,9 +79,20 @@ class CohortFilter: will be applied. :type column: str """ + PREDICTED_Y = 'Predicted Y' + TRUE_Y = 'True Y' + INDEX = 'Index' + CLASSIFICATION_OUTCOME = 'Classification outcome' + REGRESSION_ERROR = 'Error' + + SPECIAL_COLUMN_LIST = [INDEX, + PREDICTED_Y, + TRUE_Y, + CLASSIFICATION_OUTCOME, + REGRESSION_ERROR] + def __init__(self, method: str, arg: List[Any], column: str): """Defines the cohort filter. - :param method: Cohort filter method from one of CohortFilterMethods. :type method: str :param arg: List of values to be used by the cohort filter. @@ -89,7 +113,6 @@ def __init__(self, method: str, arg: List[Any], column: str): def _validate_cohort_filter_parameters( self, method: str, arg: List[Any], column: str): """Validate the input values for the cohort filter. - :param method: Cohort filter method from one of CohortFilterMethods. :type method: str :param arg: List of values to be used by the cohort filter. @@ -97,7 +120,27 @@ def _validate_cohort_filter_parameters( :param column: The column name from the dataset on which the filter will be applied. :type column: str + + The following validations can be performed on the cohort filter:- + + 1. Verify the correct types for method (expected string), column + (expected string) and arg (expected list). + 2. The method value should be one of the filter string from + CohortFilterMethods.ALL. + 3. The arg shouldn't be an empty list. + 4. For all cohort filter methods in + CohortFilterMethods.SINGLE_VALUE_METHODS, the value in the arg + should be integer or float and there should be only one value + in arg. + 5. For cohort filter method CohortFilterMethods.METHOD_RANGE, + the values in the arg should be integer or float and there + should be only two values in arg. """ + if not isinstance(method, str): + raise UserConfigValidationException( + "Got unexpected type {0} for method. " + "Expected string type.".format(type(method)) + ) if method not in CohortFilterMethods.ALL: raise UserConfigValidationException( "Got unexpected value {0} for method. " @@ -140,39 +183,290 @@ def _validate_cohort_filter_parameters( "cohort method {0}.".format( CohortFilterMethods.METHOD_RANGE) ) - if ((not isinstance(arg[0], int) and - not isinstance(arg[0], float)) or - (not isinstance(arg[1], int) and - not isinstance(arg[1], float))): + if (not all(isinstance(entry, int) for entry in arg) and + not all(isinstance(entry, float) for entry in arg)): raise UserConfigValidationException( "Expected int or float type for arg " "with cohort method {0}.".format( CohortFilterMethods.METHOD_RANGE) ) + def _validate_with_test_data(self, test_data: pd.DataFrame, + target_column: str, + categorical_features: List[str], + is_classification: Optional[bool] = True): + """ + Validate the cohort filters parameters with respect to test data. + + :param test_data: Test data over which cohort analysis will be done + in ResponsibleAI Dashboard. + :type test_data: pd.DataFrame + :param target_column: The target column in the test data. + :type target_column: str + :param categorical_features: The categorical feature names. + :type categorical_features: list[str] + :param is_classification: True to indicate if this validation needs + to be done for a classification scenario and False to indicate + that this needs to be done for regression scenario. + :type is_classification: bool + + The following validations need to be performed for cohort filter with + test data:- + + High level validations + 1. Validate if the filter column is present in the test data. + 2. Validate if the filter column is present in the special column + list. + + "Index" Filter validations + 1. The Index filter only takes integer arguments. + 2. The Index filter doesn't take CohortFilterMethods.EXCLUDES + filter method. + + "Classification outcome" Filter validations + 1. Validate that "Classification outcome" filter is not configure for + multiclass classification and regression. + 2. The "Classification outcome" filter only contains values from set + ClassificationOutcomes. + 3. The "Classification outcome" filter only takes + CohortFilterMethods.INCLUDES filter method. + + "Error" Filter validations + 1. Validate that "Error" filter is not configure for + multiclass classification and binary classification. + 2. Only integer or floating points can be configured as arguments. + 3. The CohortFilterMethods.INCLUDES and CohortFilterMethods.EXCLUDES + filter methods cannot be configured for this filter. + + "Predicted Y/True Y" Filter validations + 1. The set of classes configured in case of classification is a + superset of the classes available in the test data. + 2. The CohortFilterMethods.INCLUDES is only allowed to be + configured for "Predicted Y" filter in case of classification. + 3. The CohortFilterMethods.INCLUDES and CohortFilterMethods.EXCLUDES + filter methods cannot be configured for this filter for regression. + + "Dataset" Filter validations + 1. TODO:- For continuous features the allowed values that be configured + should be within the range of minimum and maximum values available + within the continuous feature column in the test data. + 2. For categorical features only CohortFilterMethods.INCLUDES can be + configured. + 3. For categorical features the values allowed are a subset of the + the values available in the categorical column in the test data. + """ + # High level validations + if self.column not in CohortFilter.SPECIAL_COLUMN_LIST and \ + (self.column not in + (set(test_data.columns) - set([target_column]))): + raise UserConfigValidationException( + "Unknown column {0} specified in cohort filter".format( + self.column) + ) + + if self.column == CohortFilter.INDEX: + # "Index" Filter validations + if self.method == CohortFilterMethods.METHOD_EXCLUDES: + raise UserConfigValidationException( + "{0} filter is not supported with {1} based " + "selection.".format( + CohortFilterMethods.METHOD_EXCLUDES, + CohortFilter.INDEX) + ) + + if not all(isinstance(entry, int) for entry in self.arg): + raise UserConfigValidationException( + "All entries in arg should be of type int." + ) + elif self.column == CohortFilter.CLASSIFICATION_OUTCOME: + # "Classification outcome" Filter validations + is_multiclass = len(np.unique( + test_data[target_column].values).tolist()) > 2 + + if not is_classification or is_multiclass: + raise UserConfigValidationException( + "{0} cannot be configured for multi-class classification" + " and regression scenarios.".format( + CohortFilter.CLASSIFICATION_OUTCOME) + ) + + if self.method != CohortFilterMethods.METHOD_INCLUDES: + raise UserConfigValidationException( + "{0} can only be configured with " + "cohort filter {1}.".format( + CohortFilter.CLASSIFICATION_OUTCOME, + CohortFilterMethods.METHOD_INCLUDES) + ) + + for classification_outcome in self.arg: + if classification_outcome not in ClassificationOutcomes.ALL: + raise UserConfigValidationException( + "{0} can only take argument values from {1}.".format( + CohortFilter.CLASSIFICATION_OUTCOME, + " or ".join(ClassificationOutcomes.ALL)) + ) + elif self.column == CohortFilter.REGRESSION_ERROR: + # "Error" Filter validations + if is_classification: + raise UserConfigValidationException( + "{0} cannot be configured for classification" + " scenarios.".format(CohortFilter.REGRESSION_ERROR) + ) + + if self.method == CohortFilterMethods.METHOD_INCLUDES or \ + self.method == CohortFilterMethods.METHOD_EXCLUDES: + raise UserConfigValidationException( + "{0} cannot be configured with either {1} or {2}.".format( + CohortFilter.REGRESSION_ERROR, + CohortFilterMethods.METHOD_INCLUDES, + CohortFilterMethods.METHOD_EXCLUDES + ) + ) + + if not all(isinstance(entry, int) for entry in self.arg) and \ + not all(isinstance(entry, float) for entry in self.arg): + raise UserConfigValidationException( + "All entries in arg should be of type int or float" + " for {} cohort.".format(CohortFilter.REGRESSION_ERROR) + ) + elif self.column == CohortFilter.PREDICTED_Y or \ + self.column == CohortFilter.TRUE_Y: + # "Predicted Y/True Y" Filter validations + if is_classification: + if self.method != CohortFilterMethods.METHOD_INCLUDES: + raise UserConfigValidationException( + "{0} can only be configured with " + "filter {1} for classification".format( + self.column, + CohortFilterMethods.METHOD_INCLUDES) + ) + + test_classes = np.unique( + test_data[target_column].values).tolist() + + if not all(entry in test_classes for entry in self.arg): + raise UserConfigValidationException( + "Found a class in arg which is not present in " + "test data") + else: + if self.method == CohortFilterMethods.METHOD_INCLUDES or \ + self.method == CohortFilterMethods.METHOD_EXCLUDES: + raise UserConfigValidationException( + "{0} cannot be configured with " + "filter {1} for regression.".format( + self.column, self.method) + ) + else: + # "Dataset" Filter validations + if self.column in categorical_features: + if self.method != CohortFilterMethods.METHOD_INCLUDES: + raise UserConfigValidationException( + "{0} is a categorical feature and should be only " + "configured with {1} cohort filter.".format( + self.column, + CohortFilterMethods.METHOD_INCLUDES) + ) + + categories = np.unique( + test_data[self.column].values).tolist() + + for entry in self.arg: + if entry not in categories: + raise UserConfigValidationException( + "Found a category {0} in arg which is not present " + "in test data column {1}.".format( + entry, self.column) + ) + class Cohort: """Defines the cohort which will be injected from SDK into the Dashboard. - :param name: Name of the cohort. :type name: str """ def __init__(self, name: str): """Defines the cohort which will be injected from SDK into the Dashboard. - :param name: Name of the cohort. :type name: str """ + if not isinstance(name, str): + raise UserConfigValidationException( + "Got unexpected type {0} for cohort name. " + "Expected string type.".format(type(name)) + ) self.name = name self.cohort_filter_list = None def add_cohort_filter(self, cohort_filter: CohortFilter): """Add a cohort filter into the cohort. - :param cohort_filter: Cohort filter defined by CohortFilter class. :type: CohortFilter """ + if not isinstance(cohort_filter, CohortFilter): + raise UserConfigValidationException( + "Got unexpected type {0} for cohort filter. " + "Expected CohortFilter type.".format(type(cohort_filter)) + ) if self.cohort_filter_list is None: self.cohort_filter_list = [cohort_filter] else: self.cohort_filter_list.append(cohort_filter) + + def _validate_with_test_data(self, test_data: pd.DataFrame, + target_column: str, + categorical_features: List[str], + is_classification: Optional[bool] = True): + """ + Validate the cohort and cohort filters parameters with respect to + test data. + + :param test_data: Test data over which cohort analysis will be done + in ResponsibleAI Dashboard. + :type test_data: pd.DataFrame + :param target_column: The target column in the test data. + :type target_column: str + :param categorical_features: The categorical feature names. + :type categorical_features: list[str] + :param is_classification: True to indicate if this validation needs + to be done for a classification scenario and False to indicate + that this needs to be done for regression scenario. + :type is_classification: bool + """ + if self.cohort_filter_list is None: + return + if not isinstance(test_data, pd.DataFrame): + raise UserConfigValidationException( + "The test_data should be a pandas DataFrame.") + if not isinstance(target_column, str): + raise UserConfigValidationException( + "The target_column should be string.") + if not isinstance(categorical_features, list): + raise UserConfigValidationException( + "Expected a list type for categorical columns.") + for categorical_feature in categorical_features: + if not isinstance(categorical_feature, str): + raise UserConfigValidationException( + "Feature {0} in categorical_features need to be of " + "string type.".format(categorical_feature) + ) + + if target_column not in test_data.columns: + raise UserConfigValidationException( + "The target_column {0} was not found in test_data.".format( + target_column) + ) + + test_data_columns_set = set(test_data.columns) - set([target_column]) + for categorical_feature in categorical_features: + if categorical_feature not in test_data_columns_set: + raise UserConfigValidationException( + "Found categorical feature {0} which is not" + " present in test data.".format(categorical_feature) + ) + + for cohort_filter in self.cohort_filter_list: + cohort_filter._validate_with_test_data( + test_data=test_data, + target_column=target_column, + categorical_features=categorical_features, + is_classification=is_classification) diff --git a/raiwidgets/tests/test_cohort.py b/raiwidgets/tests/test_cohort.py index 5300715a8e..8bae46f415 100644 --- a/raiwidgets/tests/test_cohort.py +++ b/raiwidgets/tests/test_cohort.py @@ -3,15 +3,41 @@ import json +import pandas as pd import pytest -from raiwidgets._cohort import (Cohort, CohortFilter, CohortFilterMethods, +from raiwidgets._cohort import (ClassificationOutcomes, Cohort, CohortFilter, + CohortFilterMethods, cohort_filter_json_converter) from responsibleai.exceptions import UserConfigValidationException +def get_toy_binary_classification_dataset(): + return pd.DataFrame(data=[[23, 'X'], [25, 'Y']], + columns=["age", "target"]) + + +def get_toy_multiclass_classification_dataset(): + return pd.DataFrame( + data=[[23, 'X'], [25, 'Y'], [25, 'Z']], + columns=["age", "target"]) + + +def get_toy_regression_dataset(): + return pd.DataFrame( + data=[[23, 2.5], [25, 3.6], [25, 4.6]], + columns=["age", "target"]) + + class TestCohortFilter: def test_cohort_filter_validate_method(self): + with pytest.raises( + UserConfigValidationException, + match="Got unexpected type for method. " + "Expected string type."): + CohortFilter(method=1, + arg=[], column=1) + with pytest.raises( UserConfigValidationException, match="Got unexpected value random for method. " @@ -94,9 +120,8 @@ def test_cohort_filter_validate_in_range_methods_type_arg_entries( @pytest.mark.parametrize('method', CohortFilterMethods.SINGLE_VALUE_METHODS) def test_cohort_filter_serialization_single_value_methods(self, method): - cohort_filter_1 = \ - CohortFilter(method=method, - arg=[65.0], column='age') + cohort_filter_1 = CohortFilter(method=method, + arg=[65.0], column='age') json_str = json.dumps(cohort_filter_1, default=cohort_filter_json_converter) assert method in json_str @@ -104,9 +129,9 @@ def test_cohort_filter_serialization_single_value_methods(self, method): assert 'age' in json_str def test_cohort_filter_serialization_in_range_method(self): - cohort_filter_1 = \ - CohortFilter(method=CohortFilterMethods.METHOD_RANGE, - arg=[65.0, 70.0], column='age') + cohort_filter_1 = CohortFilter( + method=CohortFilterMethods.METHOD_RANGE, + arg=[65.0, 70.0], column='age') json_str = json.dumps(cohort_filter_1, default=cohort_filter_json_converter) assert CohortFilterMethods.METHOD_RANGE in json_str @@ -118,9 +143,9 @@ def test_cohort_filter_serialization_in_range_method(self): [CohortFilterMethods.METHOD_INCLUDES, CohortFilterMethods.METHOD_EXCLUDES]) def test_cohort_filter_serialization_include_exclude_methods(self, method): - cohort_filter_str = \ - CohortFilter(method=method, - arg=['val1', 'val2', 'val3'], column='age') + cohort_filter_str = CohortFilter(method=method, + arg=['val1', 'val2', 'val3'], + column='age') json_str = json.dumps(cohort_filter_str, default=cohort_filter_json_converter) assert method in json_str @@ -129,9 +154,9 @@ def test_cohort_filter_serialization_include_exclude_methods(self, method): assert 'val3' in json_str assert 'age' in json_str - cohort_filter_int = \ - CohortFilter(method=method, - arg=[1, 2, 3], column='age') + cohort_filter_int = CohortFilter(method=method, + arg=[1, 2, 3], + column='age') json_str = json.dumps(cohort_filter_int, default=cohort_filter_json_converter) assert method in json_str @@ -141,13 +166,335 @@ def test_cohort_filter_serialization_include_exclude_methods(self, method): assert 'age' in json_str +class TestCohortFilterDataValidations: + def test_validate_with_test_data_high_level_validations(self): + test_data = get_toy_binary_classification_dataset() + + cohort_filter_not_a_feature = CohortFilter( + method=CohortFilterMethods.METHOD_LESS, + arg=[65], column='fake_column') + + with pytest.raises( + UserConfigValidationException, + match="Unknown column fake_column specified in cohort filter"): + cohort_filter_not_a_feature._validate_with_test_data( + test_data=test_data, target_column="target", + categorical_features=[]) + + def test_validate_with_test_data_index_filter_validations(self): + test_data = get_toy_binary_classification_dataset() + + cohort_filter_index_excludes = CohortFilter( + method=CohortFilterMethods.METHOD_EXCLUDES, + arg=[65], column=CohortFilter.INDEX) + with pytest.raises( + UserConfigValidationException, + match="excludes filter is not supported with Index based " + "selection."): + cohort_filter_index_excludes._validate_with_test_data( + test_data=test_data, target_column="target", + categorical_features=[] + ) + + cohort_filter_index_incorrect_args = CohortFilter( + method=CohortFilterMethods.METHOD_GREATER, + arg=[65.0], column=CohortFilter.INDEX) + with pytest.raises( + UserConfigValidationException, + match="All entries in arg should be of type int."): + cohort_filter_index_incorrect_args._validate_with_test_data( + test_data=test_data, target_column="target", + categorical_features=[] + ) + + def test_validate_with_test_data_classification_error_filter_validations( + self): + test_data_multiclass = get_toy_multiclass_classification_dataset() + + test_data_binary = get_toy_binary_classification_dataset() + + cohort_filter_classification_excludes = CohortFilter( + method=CohortFilterMethods.METHOD_EXCLUDES, + arg=[ClassificationOutcomes.FALSE_NEGATIVE], + column=CohortFilter.CLASSIFICATION_OUTCOME) + + cohort_filter_classification_includes = CohortFilter( + method=CohortFilterMethods.METHOD_INCLUDES, + arg=["random"], + column=CohortFilter.CLASSIFICATION_OUTCOME) + + with pytest.raises( + UserConfigValidationException, + match="Classification outcome cannot be " + "configured for multi-class classification" + " and regression scenarios."): + cohort_filter_classification_excludes._validate_with_test_data( + test_data=test_data_multiclass, target_column="target", + categorical_features=[], is_classification=True + ) + + with pytest.raises( + UserConfigValidationException, + match="Classification outcome cannot be " + "configured for multi-class classification" + " and regression scenarios."): + cohort_filter_classification_excludes._validate_with_test_data( + test_data=test_data_binary, target_column="target", + categorical_features=[], is_classification=False + ) + + with pytest.raises( + UserConfigValidationException, + match="Classification outcome can only be configured with " + "cohort filter includes."): + cohort_filter_classification_excludes._validate_with_test_data( + test_data=test_data_binary, target_column="target", + categorical_features=[], is_classification=True + ) + + with pytest.raises( + UserConfigValidationException, + match="Classification outcome can only take argument values " + "from False negative or False positive or True " + "negative or True positive."): + cohort_filter_classification_includes._validate_with_test_data( + test_data=test_data_binary, target_column="target", + categorical_features=[], is_classification=True) + + def test_validate_with_test_data_regression_error_filter_validations( + self): + test_data_regression = get_toy_regression_dataset() + + cohort_filter_regression = CohortFilter( + method=CohortFilterMethods.METHOD_LESS, + arg=[2.5], + column=CohortFilter.REGRESSION_ERROR) + + with pytest.raises( + UserConfigValidationException, + match="Error cannot be configured for classification" + " scenarios."): + cohort_filter_regression._validate_with_test_data( + test_data=test_data_regression, + target_column="target", + categorical_features=[], + is_classification=True) + + with pytest.raises( + UserConfigValidationException, + match="Error cannot be configured with either includes" + " or excludes."): + cohort_filter_regression.method = \ + CohortFilterMethods.METHOD_INCLUDES + cohort_filter_regression._validate_with_test_data( + test_data=test_data_regression, + target_column="target", + categorical_features=[], + is_classification=False) + + with pytest.raises( + UserConfigValidationException, + match="Error cannot be configured with either includes" + " or excludes."): + cohort_filter_regression.method = \ + CohortFilterMethods.METHOD_EXCLUDES + cohort_filter_regression._validate_with_test_data( + test_data=test_data_regression, + target_column="target", + categorical_features=[], + is_classification=False) + + with pytest.raises( + UserConfigValidationException, + match="All entries in arg should be of type int or float" + " for Error cohort."): + cohort_filter_regression.method = \ + CohortFilterMethods.METHOD_GREATER + cohort_filter_regression.arg = ['val1', 'val2'] + cohort_filter_regression._validate_with_test_data( + test_data=test_data_regression, + target_column="target", + categorical_features=[], + is_classification=False) + + @pytest.mark.parametrize('target_filter_type', + [CohortFilter.PREDICTED_Y, + CohortFilter.TRUE_Y]) + @pytest.mark.parametrize('method', + [CohortFilterMethods.METHOD_INCLUDES, + CohortFilterMethods.METHOD_EXCLUDES]) + def test_validate_with_test_data_regression_target_filter_validations( + self, target_filter_type, method): + test_data_regression = get_toy_regression_dataset() + + with pytest.raises( + UserConfigValidationException, + match="{0} cannot be configured with " + "filter {1} for regression.".format(target_filter_type, + method)): + cohort_filter_regression = CohortFilter( + method=method, + arg=[2.5], + column=target_filter_type) + cohort_filter_regression._validate_with_test_data( + test_data=test_data_regression, + target_column="target", + categorical_features=[], + is_classification=False) + + @pytest.mark.parametrize('target_filter_type', + [CohortFilter.PREDICTED_Y, + CohortFilter.TRUE_Y]) + def test_validate_with_test_data_classification_target_filter_validations( + self, target_filter_type): + test_data_classification = get_toy_binary_classification_dataset() + + with pytest.raises( + UserConfigValidationException, + match="{0} can only be configured with " + "filter {1} for classification".format( + target_filter_type, + CohortFilterMethods.METHOD_INCLUDES)): + cohort_filter_classification = CohortFilter( + method=CohortFilterMethods.METHOD_EXCLUDES, + arg=['X'], + column=target_filter_type) + cohort_filter_classification._validate_with_test_data( + test_data=test_data_classification, + target_column="target", + categorical_features=[], + is_classification=True) + + with pytest.raises( + UserConfigValidationException, + match="Found a class in arg which is not present in " + "test data"): + cohort_filter_classification = CohortFilter( + method=CohortFilterMethods.METHOD_INCLUDES, + arg=['Z'], + column=target_filter_type) + cohort_filter_classification._validate_with_test_data( + test_data=test_data_classification, + target_column="target", + categorical_features=[], + is_classification=True) + + def test_validate_with_test_data_with_dataset_validations( + self): + test_data = pd.DataFrame( + data=[[23, 'new', 'A'], [25, 'new, ''B'], [25, 'old', 'B']], + columns=["age", 'type', "target"]) + + with pytest.raises( + UserConfigValidationException, + match="{0} is a categorical feature and should be only " + "configured with {1} cohort filter.".format( + "type", + CohortFilterMethods.METHOD_INCLUDES)): + cohort_filter = CohortFilter( + method=CohortFilterMethods.METHOD_EXCLUDES, + arg=['new'], + column='type') + cohort_filter._validate_with_test_data( + test_data=test_data, + target_column="target", + categorical_features=['type'], + is_classification=True) + + with pytest.raises( + UserConfigValidationException, + match="Found a category {0} in arg which is not present " + "in test data column {1}.".format('mid', 'type')): + cohort_filter = CohortFilter( + method=CohortFilterMethods.METHOD_INCLUDES, + arg=['mid'], + column='type') + cohort_filter._validate_with_test_data( + test_data=test_data, + target_column="target", + categorical_features=['type'], + is_classification=True) + + class TestCohort: + def test_cohort_configuration_validations(self): + with pytest.raises( + UserConfigValidationException, + match="Got unexpected type for cohort name. " + "Expected string type."): + Cohort(name=1) + + with pytest.raises( + UserConfigValidationException, + match="Got unexpected type for cohort filter. " + "Expected CohortFilter type"): + cohort = Cohort(name="Cohort New") + cohort.add_cohort_filter(cohort_filter=[]) + + def test_cohort_validate_with_test_data(self): + cohort_filter_1 = CohortFilter( + method=CohortFilterMethods.METHOD_LESS, + arg=[65], column='age') + cohort_1 = Cohort(name="Cohort New") + cohort_1.add_cohort_filter(cohort_filter_1) + test_data = get_toy_binary_classification_dataset() + + with pytest.raises( + UserConfigValidationException, + match="The test_data should be a pandas DataFrame"): + cohort_1._validate_with_test_data( + test_data=[], target_column='target', + categorical_features=[]) + + with pytest.raises( + UserConfigValidationException, + match="The target_column should be string."): + cohort_1._validate_with_test_data( + test_data=test_data, + target_column=1, + categorical_features=[]) + + with pytest.raises( + UserConfigValidationException, + match="The target_column fake_target " + "was not found in test_data."): + cohort_1._validate_with_test_data( + test_data=test_data, + target_column="fake_target", + categorical_features=[]) + + with pytest.raises( + UserConfigValidationException, + match="Expected a list type for " + "categorical columns."): + cohort_1._validate_with_test_data( + test_data=test_data, + target_column="target", + categorical_features={}) + + with pytest.raises( + UserConfigValidationException, + match="Feature 1 in categorical_features need to be of " + "string type."): + cohort_1._validate_with_test_data( + test_data=test_data, + target_column="target", + categorical_features=[1, 2]) + + with pytest.raises( + UserConfigValidationException, + match="Found categorical feature hours-per-week which is not" + " present in test data."): + cohort_1._validate_with_test_data( + test_data=test_data, + target_column="target", + categorical_features=["hours-per-week"]) + @pytest.mark.parametrize('method', CohortFilterMethods.SINGLE_VALUE_METHODS) def test_cohort_serialization_single_value_method(self, method): - cohort_filter_1 = \ - CohortFilter(method=method, - arg=[65], column='age') + cohort_filter_1 = CohortFilter(method=method, + arg=[65], column='age') cohort_1 = Cohort(name="Cohort New") cohort_1.add_cohort_filter(cohort_filter_1) json_str = json.dumps(cohort_1, @@ -159,9 +506,9 @@ def test_cohort_serialization_single_value_method(self, method): assert 'age' in json_str def test_cohort_serialization_in_range_method(self): - cohort_filter_1 = \ - CohortFilter(method=CohortFilterMethods.METHOD_RANGE, - arg=[65.0, 70.0], column='age') + cohort_filter_1 = CohortFilter( + method=CohortFilterMethods.METHOD_RANGE, + arg=[65.0, 70.0], column='age') cohort_1 = Cohort(name="Cohort New") cohort_1.add_cohort_filter(cohort_filter_1) json_str = json.dumps(cohort_1, @@ -177,9 +524,9 @@ def test_cohort_serialization_in_range_method(self): [CohortFilterMethods.METHOD_INCLUDES, CohortFilterMethods.METHOD_EXCLUDES]) def test_cohort_serialization_include_exclude_methods(self, method): - cohort_filter_str = \ - CohortFilter(method=method, - arg=['val1', 'val2', 'val3'], column='age') + cohort_filter_str = CohortFilter(method=method, + arg=['val1', 'val2', 'val3'], + column='age') cohort_str = Cohort(name="Cohort New Str") cohort_str.add_cohort_filter(cohort_filter_str) json_str = json.dumps(cohort_str, @@ -190,9 +537,9 @@ def test_cohort_serialization_include_exclude_methods(self, method): assert 'val3' in json_str assert 'age' in json_str - cohort_filter_int = \ - CohortFilter(method=method, - arg=[1, 2, 3], column='age') + cohort_filter_int = CohortFilter(method=method, + arg=[1, 2, 3], + column='age') cohort_int = Cohort(name="Cohort New Int") cohort_int.add_cohort_filter(cohort_filter_int) json_str = json.dumps(cohort_filter_int, @@ -206,9 +553,9 @@ def test_cohort_serialization_include_exclude_methods(self, method): class TestCohortList: def test_cohort_list_serialization(self): - cohort_filter_1 = \ - CohortFilter(method=CohortFilterMethods.METHOD_LESS, - arg=[65], column='age') + cohort_filter_1 = CohortFilter( + method=CohortFilterMethods.METHOD_LESS, + arg=[65], column='age') cohort_1 = Cohort(name="Cohort New") cohort_1.add_cohort_filter(cohort_filter_1)