diff --git a/responsibleai/responsibleai/managers/causal_manager.py b/responsibleai/responsibleai/managers/causal_manager.py index 5c1c9a5c6b..9600ebb96f 100644 --- a/responsibleai/responsibleai/managers/causal_manager.py +++ b/responsibleai/responsibleai/managers/causal_manager.py @@ -219,6 +219,36 @@ def add( f"got {heterogeneity_model}") raise UserConfigValidationException(message) + # Check treatment_cost is valid + if isinstance(treatment_cost, int) and \ + treatment_cost == 0: + revised_treatment_cost = [0] * len( + treatment_features) + else: + revised_treatment_cost = treatment_cost + + if not isinstance(revised_treatment_cost, list): + message = ( + "treatment_cost must be a list with " + "the same number of elements as " + "treatment_features where each element " + "is either a constant cost of treatment " + "or an array specifying the cost of " + "treatment per sample. " + "Found treatment_cost of type " + f"{type(revised_treatment_cost)}, expected list.") + raise UserConfigValidationException(message) + elif len(revised_treatment_cost) != \ + len(treatment_features): + message = ( + "treatment_cost must be a list with " + "the same number of elements as " + "treatment_features. " + "Length of treatment_cost was " + f"{len(revised_treatment_cost)}, expected " + f"{len(treatment_features)}.") + raise UserConfigValidationException(message) + validate_train_test_categories( train_data=self._train, test_data=self._test, @@ -233,7 +263,7 @@ def add( heterogeneity_model=heterogeneity_model, alpha=alpha, upper_bound_on_cat_expansion=upper_bound_on_cat_expansion, - treatment_cost=treatment_cost, + treatment_cost=revised_treatment_cost, min_tree_leaf_samples=min_tree_leaf_samples, max_tree_depth=max_tree_depth, skip_cat_limit_checks=skip_cat_limit_checks, @@ -405,41 +435,11 @@ def compute(self): X_test, alpha=causal_config.alpha, keep_all_levels=True) result.policies = [] - - # Check treatment_cost is valid - if isinstance(causal_config.treatment_cost, int) and \ - causal_config.treatment_cost == 0: - revised_treatment_cost = [0] * len( - causal_config.treatment_features) - else: - revised_treatment_cost = causal_config.treatment_cost - - if not isinstance(revised_treatment_cost, list): - message = ( - "treatment_cost must be a list with " - "the same number of elements as " - "treatment_features where each element " - "is either a constant cost of treatment " - "or an array specifying the cost of " - "treatment per sample. " - "Found treatment_cost of type " - f"{type(revised_treatment_cost)}, expected list.") - raise UserConfigValidationException(message) - elif len(revised_treatment_cost) != \ - len(causal_config.treatment_features): - message = ("treatment_cost must be a list with " - "the same number of elements as " - "treatment_features. " - "Length of treatment_cost was " - f"{len(revised_treatment_cost)}, expected " - f"{len(causal_config.treatment_features)}.") - raise UserConfigValidationException(message) - for i in range(len(causal_config.treatment_features)): policy = self._create_policy( result, X_test, causal_config.treatment_features[i], - revised_treatment_cost[i], + causal_config.treatment_cost[i], causal_config.alpha, causal_config.max_tree_depth, causal_config.min_tree_leaf_samples) result.policies.append(policy) diff --git a/responsibleai/tests/causal/test_causal_manager.py b/responsibleai/tests/causal/test_causal_manager.py index ab63a7309f..6331be2040 100644 --- a/responsibleai/tests/causal/test_causal_manager.py +++ b/responsibleai/tests/causal/test_causal_manager.py @@ -251,7 +251,6 @@ def test_zero_cost(self, cost_manager): cost_manager.add(['AveRooms', 'Population', 'AveOccup'], treatment_cost=0) cost_manager.compute() - except TypeError: pass mock_create.assert_any_call(ANY, ANY, 'AveRooms', 0, ANY, ANY, ANY) @@ -269,7 +268,6 @@ def test_nonzero_scalar_cost(self, cost_manager): with pytest.raises(UserConfigValidationException, match=message): cost_manager.add(['AveRooms', 'Population', 'AveOccup'], treatment_cost=5) - cost_manager.compute() def test_nonlist_cost(self, cost_manager): message = ("treatment_cost must be a list with the same number of " @@ -281,7 +279,6 @@ def test_nonlist_cost(self, cost_manager): with pytest.raises(UserConfigValidationException, match=message): cost_manager.add(['AveRooms', 'Population', 'AveOccup'], treatment_cost=np.array([1, 2])) - cost_manager.compute() def test_invalid_cost_list_length(self, cost_manager): expected = ("treatment_cost must be a list with the same number of " @@ -290,7 +287,6 @@ def test_invalid_cost_list_length(self, cost_manager): with pytest.raises(UserConfigValidationException, match=expected): cost_manager.add(['AveRooms', 'Population', 'AveOccup'], treatment_cost=[1, 2]) - cost_manager.compute() def test_constant_cost_per_treatment_feature(self, cost_manager): with patch.object(cost_manager, '_create_policy', return_value=None)\