-
Notifications
You must be signed in to change notification settings - Fork 360
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Upper bound SciKit-Learn to address freeze in causal (#1432)
<!--- Provide a general summary of your changes in the Title above --> ## Description Replaces #1429 to address #1430 . Causal analysis is getting stuck with the latest release of SciKit-Learn. This contains: - Test case which gets stuck with SciKit-Learn 1.1.0 - Upper bound on SciKit-Learn in `requirements.txt` ## Checklist <!--- Make sure to satisfy all the criteria listed below. --> - [x] I have added screenshots above for all UI changes. - [x] Documentation was updated if it was needed. - [x] New tests were added or changes were manually verified. Signed-off-by: Richard Edgar <[email protected]>
- Loading branch information
1 parent
712b54d
commit 016afd9
Showing
2 changed files
with
52 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) Microsoft Corporation | ||
# Licensed under the MIT License. | ||
|
||
|
||
from responsibleai import RAIInsights | ||
|
||
from ..causal_manager_validator import _check_causal_result | ||
from ..common_utils import create_adult_income_dataset | ||
|
||
|
||
def test_causal_classification_scikitlearn_issue(): | ||
# This test gets stuck on SciKit-Learn v1.1.0 | ||
# See PR #1429 | ||
data_train, data_test, _, _, categorical_features, \ | ||
_, target_name, classes = create_adult_income_dataset() | ||
|
||
rai_i = RAIInsights( | ||
model=None, | ||
train=data_train, | ||
test=data_test, | ||
task_type='classification', | ||
target_column=target_name, | ||
categorical_features=categorical_features, | ||
classes=classes | ||
) | ||
assert rai_i is not None | ||
|
||
treatment_features = ["age", "gender"] | ||
cat_expansion = 49 | ||
rai_i.causal.add(treatment_features=treatment_features, | ||
heterogeneity_features=["marital_status"], | ||
nuisance_model="automl", | ||
heterogeneity_model="forest", | ||
alpha=0.06, | ||
upper_bound_on_cat_expansion=cat_expansion, | ||
treatment_cost=[0.1, 0.2], | ||
min_tree_leaf_samples=2, | ||
skip_cat_limit_checks=False, | ||
categories="auto", | ||
n_jobs=1, | ||
verbose=1, | ||
random_state=100, | ||
) | ||
|
||
rai_i.compute() | ||
|
||
results = rai_i.causal.get() | ||
assert results is not None | ||
assert isinstance(results, list) | ||
assert len(results) == 1 | ||
_check_causal_result(results[0]) |