From 18d38dd0b1c1a6b1f9a9cd906ae22a5e34ed5c42 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Thu, 10 Feb 2022 20:02:17 -0800 Subject: [PATCH] Add heterogeneity_model checks (#1210) * Add hetrogenity_model checks Signed-off-by: Gaurav Gupta * Fix lint Signed-off-by: Gaurav Gupta --- .../responsibleai/_tools/causal/causal_constants.py | 1 + responsibleai/responsibleai/managers/causal_manager.py | 8 ++++++++ responsibleai/tests/causal_manager_validator.py | 7 ++++++- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/responsibleai/responsibleai/_tools/causal/causal_constants.py b/responsibleai/responsibleai/_tools/causal/causal_constants.py index 665987c5a1..b22c32c7ed 100644 --- a/responsibleai/responsibleai/_tools/causal/causal_constants.py +++ b/responsibleai/responsibleai/_tools/causal/causal_constants.py @@ -8,6 +8,7 @@ class ModelTypes: """Model type constants.""" AUTOML = 'automl' LINEAR = 'linear' + FOREST = 'forest' class DefaultParams: diff --git a/responsibleai/responsibleai/managers/causal_manager.py b/responsibleai/responsibleai/managers/causal_manager.py index b3581ec559..6255f6f1c0 100644 --- a/responsibleai/responsibleai/managers/causal_manager.py +++ b/responsibleai/responsibleai/managers/causal_manager.py @@ -172,6 +172,14 @@ def add( f"got {nuisance_model}") raise UserConfigValidationException(message) + if heterogeneity_model not in [ModelTypes.FOREST, + ModelTypes.LINEAR]: + message = (f"heterogeneity_model should be one of " + f"['{ModelTypes.FOREST}', " + f"'{ModelTypes.LINEAR}'], " + f"got {heterogeneity_model}") + raise UserConfigValidationException(message) + validate_train_test_categories( train_data=self._train, test_data=self._test, diff --git a/responsibleai/tests/causal_manager_validator.py b/responsibleai/tests/causal_manager_validator.py index 0a3f02feee..ae7e622b45 100644 --- a/responsibleai/tests/causal_manager_validator.py +++ b/responsibleai/tests/causal_manager_validator.py @@ -83,11 +83,16 @@ def validate_causal(rai_insights, data, target_column, _check_causal_properties(rai_insights.causal.list(), expected_causal_effects=2) - # Add a bad configuration + # Add a bad configuration for nuisance_model with pytest.raises(UserConfigValidationException): rai_insights.causal.add(treatment_features, nuisance_model='fake_model') + # Add a bad configuration for heterogeneity_model + with pytest.raises(UserConfigValidationException): + rai_insights.causal.add(treatment_features, + heterogeneity_model='fake_model') + def _check_causal_properties( causal_props, expected_causal_effects):