Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data validations to SDK defined cohorts #1227

Merged
merged 3 commits into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 306 additions & 12 deletions raiwidgets/raiwidgets/_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

"""Module for defining cohorts in raiwidgets package."""

from typing import Any, List
from typing import Any, List, Optional

import numpy as np
import pandas as pd

from responsibleai.exceptions import UserConfigValidationException

Expand Down Expand Up @@ -39,9 +42,20 @@ class CohortFilterMethods:
METHOD_EQUAL]


class ClassificationOutcomes:
"""Defines the possible values for classification outcomes.
"""
FALSE_NEGATIVE = 'False negative'
FALSE_POSITIVE = 'False positive'
TRUE_NEGATIVE = 'True negative'
TRUE_POSITIVE = 'True positive'

ALL = [FALSE_NEGATIVE, FALSE_POSITIVE,
TRUE_NEGATIVE, TRUE_POSITIVE]


def cohort_filter_json_converter(obj):
"""Helper function to convert CohortFilter object to json.

:param obj: Object to convert to json.
:type obj: object
:return: The converted json.
Expand All @@ -57,7 +71,6 @@ def cohort_filter_json_converter(obj):

class CohortFilter:
"""Defines the cohort filter.

:param method: Cohort filter method from one of CohortFilterMethods.
:type method: str
:param arg: List of values to be used by the cohort filter.
Expand All @@ -66,9 +79,20 @@ class CohortFilter:
will be applied.
:type column: str
"""
PREDICTED_Y = 'Predicted Y'
TRUE_Y = 'True Y'
INDEX = 'Index'
CLASSIFICATION_OUTCOME = 'Classification outcome'
REGRESSION_ERROR = 'Error'

SPECIAL_COLUMN_LIST = [INDEX,
PREDICTED_Y,
TRUE_Y,
CLASSIFICATION_OUTCOME,
REGRESSION_ERROR]

def __init__(self, method: str, arg: List[Any], column: str):
"""Defines the cohort filter.

:param method: Cohort filter method from one of CohortFilterMethods.
:type method: str
:param arg: List of values to be used by the cohort filter.
Expand All @@ -89,15 +113,34 @@ def __init__(self, method: str, arg: List[Any], column: str):
def _validate_cohort_filter_parameters(
self, method: str, arg: List[Any], column: str):
"""Validate the input values for the cohort filter.

:param method: Cohort filter method from one of CohortFilterMethods.
:type method: str
:param arg: List of values to be used by the cohort filter.
:type arg: list
:param column: The column name from the dataset on which the filter
will be applied.
:type column: str

The following validations can be performed on the cohort filter:-

1. Verify the correct types for method (expected string), column
(expected string) and arg (expected list).
2. The method value should be one of the filter string from
CohortFilterMethods.ALL.
3. The arg shouldn't be an empty list.
4. For all cohort filter methods in
CohortFilterMethods.SINGLE_VALUE_METHODS, the value in the arg
should be integer or float and there should be only one value
in arg.
5. For cohort filter method CohortFilterMethods.METHOD_RANGE,
the values in the arg should be integer or float and there
should be only two values in arg.
"""
if not isinstance(method, str):
raise UserConfigValidationException(
"Got unexpected type {0} for method. "
"Expected string type.".format(type(method))
)
if method not in CohortFilterMethods.ALL:
raise UserConfigValidationException(
"Got unexpected value {0} for method. "
Expand Down Expand Up @@ -140,39 +183,290 @@ def _validate_cohort_filter_parameters(
"cohort method {0}.".format(
CohortFilterMethods.METHOD_RANGE)
)
if ((not isinstance(arg[0], int) and
not isinstance(arg[0], float)) or
(not isinstance(arg[1], int) and
not isinstance(arg[1], float))):
if (not all(isinstance(entry, int) for entry in arg) and
not all(isinstance(entry, float) for entry in arg)):
raise UserConfigValidationException(
"Expected int or float type for arg "
"with cohort method {0}.".format(
CohortFilterMethods.METHOD_RANGE)
)

def _validate_with_test_data(self, test_data: pd.DataFrame,
target_column: str,
categorical_features: List[str],
is_classification: Optional[bool] = True):
"""
Validate the cohort filters parameters with respect to test data.

:param test_data: Test data over which cohort analysis will be done
in ResponsibleAI Dashboard.
:type test_data: pd.DataFrame
:param target_column: The target column in the test data.
:type target_column: str
:param categorical_features: The categorical feature names.
:type categorical_features: list[str]
:param is_classification: True to indicate if this validation needs
to be done for a classification scenario and False to indicate
that this needs to be done for regression scenario.
:type is_classification: bool

The following validations need to be performed for cohort filter with
test data:-

High level validations
1. Validate if the filter column is present in the test data.
2. Validate if the filter column is present in the special column
list.

"Index" Filter validations
1. The Index filter only takes integer arguments.
2. The Index filter doesn't take CohortFilterMethods.EXCLUDES
filter method.

"Classification outcome" Filter validations
1. Validate that "Classification outcome" filter is not configure for
multiclass classification and regression.
2. The "Classification outcome" filter only contains values from set
ClassificationOutcomes.
3. The "Classification outcome" filter only takes
CohortFilterMethods.INCLUDES filter method.

"Error" Filter validations
1. Validate that "Error" filter is not configure for
multiclass classification and binary classification.
2. Only integer or floating points can be configured as arguments.
3. The CohortFilterMethods.INCLUDES and CohortFilterMethods.EXCLUDES
filter methods cannot be configured for this filter.

"Predicted Y/True Y" Filter validations
1. The set of classes configured in case of classification is a
superset of the classes available in the test data.
2. The CohortFilterMethods.INCLUDES is only allowed to be
configured for "Predicted Y" filter in case of classification.
3. The CohortFilterMethods.INCLUDES and CohortFilterMethods.EXCLUDES
filter methods cannot be configured for this filter for regression.

"Dataset" Filter validations
1. TODO:- For continuous features the allowed values that be configured
should be within the range of minimum and maximum values available
within the continuous feature column in the test data.
2. For categorical features only CohortFilterMethods.INCLUDES can be
configured.
3. For categorical features the values allowed are a subset of the
the values available in the categorical column in the test data.
"""
# High level validations
if self.column not in CohortFilter.SPECIAL_COLUMN_LIST and \
(self.column not in
(set(test_data.columns) - set([target_column]))):
raise UserConfigValidationException(
"Unknown column {0} specified in cohort filter".format(
self.column)
)

if self.column == CohortFilter.INDEX:
# "Index" Filter validations
if self.method == CohortFilterMethods.METHOD_EXCLUDES:
raise UserConfigValidationException(
"{0} filter is not supported with {1} based "
"selection.".format(
CohortFilterMethods.METHOD_EXCLUDES,
CohortFilter.INDEX)
)

if not all(isinstance(entry, int) for entry in self.arg):
raise UserConfigValidationException(
"All entries in arg should be of type int."
)
elif self.column == CohortFilter.CLASSIFICATION_OUTCOME:
# "Classification outcome" Filter validations
is_multiclass = len(np.unique(
test_data[target_column].values).tolist()) > 2

if not is_classification or is_multiclass:
raise UserConfigValidationException(
"{0} cannot be configured for multi-class classification"
" and regression scenarios.".format(
CohortFilter.CLASSIFICATION_OUTCOME)
)

if self.method != CohortFilterMethods.METHOD_INCLUDES:
raise UserConfigValidationException(
"{0} can only be configured with "
"cohort filter {1}.".format(
CohortFilter.CLASSIFICATION_OUTCOME,
CohortFilterMethods.METHOD_INCLUDES)
)

for classification_outcome in self.arg:
if classification_outcome not in ClassificationOutcomes.ALL:
raise UserConfigValidationException(
"{0} can only take argument values from {1}.".format(
CohortFilter.CLASSIFICATION_OUTCOME,
" or ".join(ClassificationOutcomes.ALL))
)
elif self.column == CohortFilter.REGRESSION_ERROR:
# "Error" Filter validations
if is_classification:
raise UserConfigValidationException(
"{0} cannot be configured for classification"
" scenarios.".format(CohortFilter.REGRESSION_ERROR)
)

if self.method == CohortFilterMethods.METHOD_INCLUDES or \
self.method == CohortFilterMethods.METHOD_EXCLUDES:
raise UserConfigValidationException(
"{0} cannot be configured with either {1} or {2}.".format(
CohortFilter.REGRESSION_ERROR,
CohortFilterMethods.METHOD_INCLUDES,
CohortFilterMethods.METHOD_EXCLUDES
)
)

if not all(isinstance(entry, int) for entry in self.arg) and \
not all(isinstance(entry, float) for entry in self.arg):
raise UserConfigValidationException(
"All entries in arg should be of type int or float"
" for {} cohort.".format(CohortFilter.REGRESSION_ERROR)
)
elif self.column == CohortFilter.PREDICTED_Y or \
self.column == CohortFilter.TRUE_Y:
# "Predicted Y/True Y" Filter validations
if is_classification:
if self.method != CohortFilterMethods.METHOD_INCLUDES:
raise UserConfigValidationException(
"{0} can only be configured with "
"filter {1} for classification".format(
self.column,
CohortFilterMethods.METHOD_INCLUDES)
)

test_classes = np.unique(
test_data[target_column].values).tolist()

if not all(entry in test_classes for entry in self.arg):
raise UserConfigValidationException(
"Found a class in arg which is not present in "
"test data")
else:
if self.method == CohortFilterMethods.METHOD_INCLUDES or \
self.method == CohortFilterMethods.METHOD_EXCLUDES:
raise UserConfigValidationException(
"{0} cannot be configured with "
"filter {1} for regression.".format(
self.column, self.method)
)
else:
# "Dataset" Filter validations
if self.column in categorical_features:
if self.method != CohortFilterMethods.METHOD_INCLUDES:
raise UserConfigValidationException(
"{0} is a categorical feature and should be only "
"configured with {1} cohort filter.".format(
self.column,
CohortFilterMethods.METHOD_INCLUDES)
)

categories = np.unique(
test_data[self.column].values).tolist()

for entry in self.arg:
if entry not in categories:
raise UserConfigValidationException(
"Found a category {0} in arg which is not present "
"in test data column {1}.".format(
entry, self.column)
)


class Cohort:
"""Defines the cohort which will be injected from SDK into the Dashboard.

:param name: Name of the cohort.
:type name: str
"""
def __init__(self, name: str):
"""Defines the cohort which will be injected from SDK into the Dashboard.

:param name: Name of the cohort.
:type name: str
"""
if not isinstance(name, str):
raise UserConfigValidationException(
"Got unexpected type {0} for cohort name. "
"Expected string type.".format(type(name))
)
self.name = name
self.cohort_filter_list = None

def add_cohort_filter(self, cohort_filter: CohortFilter):
"""Add a cohort filter into the cohort.

:param cohort_filter: Cohort filter defined by CohortFilter class.
:type: CohortFilter
"""
if not isinstance(cohort_filter, CohortFilter):
raise UserConfigValidationException(
"Got unexpected type {0} for cohort filter. "
"Expected CohortFilter type.".format(type(cohort_filter))
)
if self.cohort_filter_list is None:
self.cohort_filter_list = [cohort_filter]
else:
self.cohort_filter_list.append(cohort_filter)

def _validate_with_test_data(self, test_data: pd.DataFrame,
target_column: str,
categorical_features: List[str],
is_classification: Optional[bool] = True):
"""
Validate the cohort and cohort filters parameters with respect to
test data.

:param test_data: Test data over which cohort analysis will be done
in ResponsibleAI Dashboard.
:type test_data: pd.DataFrame
:param target_column: The target column in the test data.
:type target_column: str
:param categorical_features: The categorical feature names.
:type categorical_features: list[str]
:param is_classification: True to indicate if this validation needs
to be done for a classification scenario and False to indicate
that this needs to be done for regression scenario.
:type is_classification: bool
"""
if self.cohort_filter_list is None:
return
if not isinstance(test_data, pd.DataFrame):
raise UserConfigValidationException(
"The test_data should be a pandas DataFrame.")
if not isinstance(target_column, str):
raise UserConfigValidationException(
"The target_column should be string.")
if not isinstance(categorical_features, list):
raise UserConfigValidationException(
"Expected a list type for categorical columns.")
for categorical_feature in categorical_features:
if not isinstance(categorical_feature, str):
raise UserConfigValidationException(
"Feature {0} in categorical_features need to be of "
"string type.".format(categorical_feature)
)

if target_column not in test_data.columns:
raise UserConfigValidationException(
"The target_column {0} was not found in test_data.".format(
target_column)
)

test_data_columns_set = set(test_data.columns) - set([target_column])
for categorical_feature in categorical_features:
if categorical_feature not in test_data_columns_set:
raise UserConfigValidationException(
"Found categorical feature {0} which is not"
" present in test data.".format(categorical_feature)
)

for cohort_filter in self.cohort_filter_list:
cohort_filter._validate_with_test_data(
test_data=test_data,
target_column=target_column,
categorical_features=categorical_features,
is_classification=is_classification)
Loading