Upper bound SciKit-Learn to address freeze in causal (#1432)
<!--- Provide a general summary of your changes in the Title above -->

## Description

Replaces #1429 to address #1430. Causal analysis gets stuck with the latest release of SciKit-Learn (1.1.0). This PR contains:

- Test case which gets stuck with SciKit-Learn 1.1.0
- Upper bound on SciKit-Learn in `requirements.txt`
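
For illustration only, the range that the new specifier `>=0.22.1,<1.1` admits can be sketched as below. This is a simplified, standard-library-only approximation and not how pip resolves versions: the helper names `parse_release` and `satisfies_bound` are hypothetical, and pre-release suffixes such as `rc1` are ignored rather than handled per PEP 440.

```python
import re


def parse_release(version: str) -> tuple:
    """Return the leading numeric components, e.g. '1.1.0rc1' -> (1, 1, 0).

    Hypothetical helper: suffixes such as 'rc1' or '.dev0' are dropped.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def satisfies_bound(version: str, lower=(0, 22, 1), upper=(1, 1)) -> bool:
    """Check lower <= version < upper, mirroring 'scikit-learn>=0.22.1,<1.1'."""
    release = parse_release(version)
    # Pad with zeros so (1, 1) and (1, 1, 0) compare as equal.
    width = max(len(release), len(lower), len(upper))

    def pad(parts):
        return tuple(parts) + (0,) * (width - len(parts))

    return pad(lower) <= pad(release) < pad(upper)


if __name__ == "__main__":
    for candidate in ("0.22.0", "0.22.1", "1.0.2", "1.1.0"):
        print(candidate, satisfies_bound(candidate))
```

Under these assumptions, 1.0.2 falls inside the range while 1.1.0, the release on which the test gets stuck, is excluded.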

## Checklist

<!--- Make sure to satisfy all the criteria listed below. -->

- [x] I have added screenshots above for all UI changes.
- [x] Documentation was updated if it was needed.
- [x] New tests were added or changes were manually verified.


Signed-off-by: Richard Edgar <[email protected]>
riedgar-ms authored and gaugup committed May 26, 2022
1 parent 712b54d commit 016afd9
Showing 2 changed files with 52 additions and 1 deletion.
2 changes: 1 addition & 1 deletion responsibleai/requirements.txt
@@ -7,7 +7,7 @@ lightgbm>=2.0.11
numpy>=1.17.2
numba<0.54.0
pandas>=0.25.1
-scikit-learn>=0.22.1
+scikit-learn>=0.22.1,<1.1 # See PR 1429 about upper bound
scipy>=1.4.1
semver~=2.13.0

51 changes: 51 additions & 0 deletions responsibleai/tests/causal/test_causal_general.py
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.


from responsibleai import RAIInsights

from ..causal_manager_validator import _check_causal_result
from ..common_utils import create_adult_income_dataset


def test_causal_classification_scikitlearn_issue():
    # This test gets stuck on SciKit-Learn v1.1.0
    # See PR #1429
    data_train, data_test, _, _, categorical_features, \
        _, target_name, classes = create_adult_income_dataset()

    rai_i = RAIInsights(
        model=None,
        train=data_train,
        test=data_test,
        task_type='classification',
        target_column=target_name,
        categorical_features=categorical_features,
        classes=classes
    )
    assert rai_i is not None

    treatment_features = ["age", "gender"]
    cat_expansion = 49
    rai_i.causal.add(treatment_features=treatment_features,
                     heterogeneity_features=["marital_status"],
                     nuisance_model="automl",
                     heterogeneity_model="forest",
                     alpha=0.06,
                     upper_bound_on_cat_expansion=cat_expansion,
                     treatment_cost=[0.1, 0.2],
                     min_tree_leaf_samples=2,
                     skip_cat_limit_checks=False,
                     categories="auto",
                     n_jobs=1,
                     verbose=1,
                     random_state=100,
                     )

    rai_i.compute()

    results = rai_i.causal.get()
    assert results is not None
    assert isinstance(results, list)
    assert len(results) == 1
    _check_causal_result(results[0])
