Skip to content

Commit

Permalink
causal_manager.py: Move user config validation checks from compute() …
Browse files Browse the repository at this point in the history
…to add() (#2342)

Signed-off-by: Gaurav Gupta <[email protected]>
  • Loading branch information
gaugup authored Sep 19, 2023
1 parent 0a45661 commit e3cbcfd
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 36 deletions.
64 changes: 32 additions & 32 deletions responsibleai/responsibleai/managers/causal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,36 @@ def add(
f"got {heterogeneity_model}")
raise UserConfigValidationException(message)

# Check treatment_cost is valid
if isinstance(treatment_cost, int) and \
treatment_cost == 0:
revised_treatment_cost = [0] * len(
treatment_features)
else:
revised_treatment_cost = treatment_cost

if not isinstance(revised_treatment_cost, list):
message = (
"treatment_cost must be a list with "
"the same number of elements as "
"treatment_features where each element "
"is either a constant cost of treatment "
"or an array specifying the cost of "
"treatment per sample. "
"Found treatment_cost of type "
f"{type(revised_treatment_cost)}, expected list.")
raise UserConfigValidationException(message)
elif len(revised_treatment_cost) != \
len(treatment_features):
message = (
"treatment_cost must be a list with "
"the same number of elements as "
"treatment_features. "
"Length of treatment_cost was "
f"{len(revised_treatment_cost)}, expected "
f"{len(treatment_features)}.")
raise UserConfigValidationException(message)

validate_train_test_categories(
train_data=self._train,
test_data=self._test,
Expand All @@ -233,7 +263,7 @@ def add(
heterogeneity_model=heterogeneity_model,
alpha=alpha,
upper_bound_on_cat_expansion=upper_bound_on_cat_expansion,
treatment_cost=treatment_cost,
treatment_cost=revised_treatment_cost,
min_tree_leaf_samples=min_tree_leaf_samples,
max_tree_depth=max_tree_depth,
skip_cat_limit_checks=skip_cat_limit_checks,
Expand Down Expand Up @@ -405,41 +435,11 @@ def compute(self):
X_test, alpha=causal_config.alpha, keep_all_levels=True)

result.policies = []

# Check treatment_cost is valid
if isinstance(causal_config.treatment_cost, int) and \
causal_config.treatment_cost == 0:
revised_treatment_cost = [0] * len(
causal_config.treatment_features)
else:
revised_treatment_cost = causal_config.treatment_cost

if not isinstance(revised_treatment_cost, list):
message = (
"treatment_cost must be a list with "
"the same number of elements as "
"treatment_features where each element "
"is either a constant cost of treatment "
"or an array specifying the cost of "
"treatment per sample. "
"Found treatment_cost of type "
f"{type(revised_treatment_cost)}, expected list.")
raise UserConfigValidationException(message)
elif len(revised_treatment_cost) != \
len(causal_config.treatment_features):
message = ("treatment_cost must be a list with "
"the same number of elements as "
"treatment_features. "
"Length of treatment_cost was "
f"{len(revised_treatment_cost)}, expected "
f"{len(causal_config.treatment_features)}.")
raise UserConfigValidationException(message)

for i in range(len(causal_config.treatment_features)):
policy = self._create_policy(
result, X_test,
causal_config.treatment_features[i],
revised_treatment_cost[i],
causal_config.treatment_cost[i],
causal_config.alpha, causal_config.max_tree_depth,
causal_config.min_tree_leaf_samples)
result.policies.append(policy)
Expand Down
4 changes: 0 additions & 4 deletions responsibleai/tests/causal/test_causal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def test_zero_cost(self, cost_manager):
cost_manager.add(['AveRooms', 'Population', 'AveOccup'],
treatment_cost=0)
cost_manager.compute()

except TypeError:
pass
mock_create.assert_any_call(ANY, ANY, 'AveRooms', 0, ANY, ANY, ANY)
Expand All @@ -269,7 +268,6 @@ def test_nonzero_scalar_cost(self, cost_manager):
with pytest.raises(UserConfigValidationException, match=message):
cost_manager.add(['AveRooms', 'Population', 'AveOccup'],
treatment_cost=5)
cost_manager.compute()

def test_nonlist_cost(self, cost_manager):
message = ("treatment_cost must be a list with the same number of "
Expand All @@ -281,7 +279,6 @@ def test_nonlist_cost(self, cost_manager):
with pytest.raises(UserConfigValidationException, match=message):
cost_manager.add(['AveRooms', 'Population', 'AveOccup'],
treatment_cost=np.array([1, 2]))
cost_manager.compute()

def test_invalid_cost_list_length(self, cost_manager):
expected = ("treatment_cost must be a list with the same number of "
Expand All @@ -290,7 +287,6 @@ def test_invalid_cost_list_length(self, cost_manager):
with pytest.raises(UserConfigValidationException, match=expected):
cost_manager.add(['AveRooms', 'Population', 'AveOccup'],
treatment_cost=[1, 2])
cost_manager.compute()

def test_constant_cost_per_treatment_feature(self, cost_manager):
with patch.object(cost_manager, '_create_policy', return_value=None)\
Expand Down

0 comments on commit e3cbcfd

Please sign in to comment.