Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cython implementation of GRF and CausalForestDML #341

Merged
merged 129 commits into from
Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
129 commits
Select commit Hold shift + click to select a range
5b88c6a
added backend option in causal forest
vasilismsr Nov 19, 2020
ff88cde
added backend option in causal forest
vasilismsr Nov 19, 2020
b20f20f
adding verbosity, restructuring static functions
vasilismsr Nov 20, 2020
57e4fe0
started cython causal forests
vasilismsr Nov 29, 2020
a387f9c
added version of grf criterion that optimizes mse and reweights heter…
vasilismsr Dec 1, 2020
82ae834
notebook update
vasilismsr Dec 1, 2020
42eaa1c
added a heterogeneity importance score that seems more robust for grf…
vasilismsr Dec 1, 2020
341c34e
added notebooks that implement the IV grf and the local linear grf.
vasilismsr Dec 2, 2020
28c77b9
added local linear IV
vasilismsr Dec 2, 2020
ce2c1e8
added python wrapper classes for grf
vasilismsr Dec 4, 2020
66ebf66
clean notebook
vasilismsr Dec 4, 2020
526f59f
clean notebook
vasilismsr Dec 4, 2020
2ea9441
switched forest prediction to the correct mean over mean
vasilismsr Dec 4, 2020
8e2aa88
made het criterion also proxy_children_impurity=True
vasilismsr Dec 4, 2020
a8496f9
made het criterion also proxy_children_impurity=True
vasilismsr Dec 4, 2020
fbd5166
fixed randomness in train_test split. Was using same exact seed as ra…
vasilismsr Dec 4, 2020
e682493
notebook
vasilismsr Dec 4, 2020
c331701
simplified cython code to deal with val and train
vasilismsr Dec 4, 2020
ef4fd82
implemented first version of predict interval
vasilismsr Dec 5, 2020
aae5229
added halfsample inference
vasilismsr Dec 5, 2020
da46c7e
cleaned up notebook
vasilismsr Dec 5, 2020
9a5f246
speed improvements
vasilismsr Dec 5, 2020
e0aff49
notebook
vasilismsr Dec 5, 2020
ca22b95
more speed ups
vasilismsr Dec 6, 2020
f340e34
speed ups
vasilismsr Dec 6, 2020
5e1112d
further speed ups. squeezing it out :)
vasilismsr Dec 6, 2020
e11edce
switched the way tree fitting is spawn for speed improvement
vasilismsr Dec 7, 2020
acb5a6e
fixed variable importance. Changed het criterion to only use the rele…
vasilismsr Dec 7, 2020
7c08f72
fixed balance of halfsamples and subsamples.
vasilismsr Dec 7, 2020
5a59ad6
added option to not fit intercept
vasilismsr Dec 7, 2020
9ca65a5
added variance correction. added n_subforests parameter.
vasilismsr Dec 7, 2020
8e3e5b3
added variance correction. added n_subforests parameter.
vasilismsr Dec 7, 2020
2fb998f
fixed variance correction. added objective bayes debiasing. changed d…
vasilismsr Dec 8, 2020
45eb277
added regression forest
vasilismsr Dec 8, 2020
7a68ca9
cleaning up unused code
vasilismsr Dec 8, 2020
60b1854
reverted back to return whole covaraince matrix in forest class, as i…
vasilismsr Dec 9, 2020
6e1c588
added CausalForestDML cate estimator
vasilismsr Dec 9, 2020
ed38370
fixed bug with negative effect inference due to non-bayesian-debiasin…
vasilismsr Dec 9, 2020
4424b5b
removed redundant comment from cate_estimators
vasilismsr Dec 9, 2020
f21f125
deprecating older CausalForest
vasilismsr Dec 10, 2020
b6949ee
increased level of confidence in CF and ORF notebook
vasilismsr Dec 10, 2020
369c1f3
increased level of confidence in CF and ORF notebook
vasilismsr Dec 10, 2020
54cd503
notebook
vasilismsr Dec 10, 2020
050d5a4
restructured dml into folder. Deprecated ForestDML by CausalForestDML…
vasilismsr Dec 10, 2020
0e1a080
rlearner imports
vasilismsr Dec 10, 2020
ce62c99
testing notebooks
vasilismsr Dec 10, 2020
7b8ccbd
dml mnotebooks
vasilismsr Dec 10, 2020
fbf498e
deprecating ensemble.SubsampledHonestForest
vasilismsr Dec 10, 2020
528c365
added test module. made drlearner use the non dprecated regression fo…
vasilismsr Dec 10, 2020
d196b5f
added tests for tree module and for cython grf. enabled min_var_leaf …
vasilismsr Dec 12, 2020
d806546
fixed random vector in fast min eigv and fast max eigv
vasilismsr Dec 12, 2020
2036887
added tests for grf cython code.
vasilismsr Dec 12, 2020
147851b
added tests for all cython grf code. Fixed small bug in lstsq function.
vasilismsr Dec 12, 2020
649da10
fixed small issues with random seeds. Added tests for causal/regressi…
vasilismsr Dec 13, 2020
37129e6
better var leaf checking, using pairwise determinant
vasilismsr Dec 13, 2020
fcb4a24
fixed small bug in eigenclipping in criterion.
vasilismsr Dec 14, 2020
122ae12
added tests for python grf code
vasilismsr Dec 14, 2020
084f3b2
finished tests for grf python. fixed small bugs.
vasilismsr Dec 14, 2020
6b60345
finished tests for grf python. fixed small bugs.
vasilismsr Dec 14, 2020
b979966
added tests for CausalForestDML
vasilismsr Dec 14, 2020
26646e7
Merge branch 'master' into vasilis/grf_simplification
vasilismsr Dec 14, 2020
ec8f810
Merge branch 'master' into vasilis/grf_simplification
vsyrgkanis Dec 14, 2020
4ccdede
deprecating subsampledhonestforest
vasilismsr Dec 14, 2020
2d2d1d4
changed orf notebook to non-deprecated stuff
vasilismsr Dec 15, 2020
e1cfeee
rerun notebook
vasilismsr Dec 15, 2020
72e227a
rerun notebook
vasilismsr Dec 15, 2020
a9a36ba
added comments
vasilismsr Dec 15, 2020
246df34
added comments in all tree module
vasilismsr Dec 15, 2020
ce3f4a6
more commenting
vasilismsr Dec 15, 2020
9378aef
Merge branch 'master' into vasilis/grf_simplification
vsyrgkanis Dec 16, 2020
1506ddf
comments and docstrings
vasilismsr Dec 16, 2020
18dbeca
comments and small parameter re-structuring
vasilismsr Dec 16, 2020
9e63df8
more comments in _base_grftree.py
vasilismsr Dec 17, 2020
75e8a93
comments
vasilismsr Dec 17, 2020
c928ab7
more comments and docstrings
vasilismsr Dec 17, 2020
7021281
finished all comments and docstrings
vasilismsr Dec 17, 2020
0f1a76e
merged master
vasilismsr Dec 17, 2020
e4ce16d
Enable setuptools build process
kbattocchi Dec 17, 2020
8041b14
linting
vasilismsr Dec 17, 2020
c88bd32
removing old testing notebooks
vasilismsr Dec 17, 2020
930ed44
fixed deprecated references in tests
vasilismsr Dec 18, 2020
e9c1079
fixed shapes of drlearner
vasilismsr Dec 18, 2020
a2c7a0b
added cython to install_requires
vasilismsr Dec 18, 2020
307edfd
fixed flaky random_state test
vasilismsr Dec 18, 2020
8b06947
fixed tests and api consistency
vasilismsr Dec 18, 2020
4e11b26
enabled sample_var as an argument
vasilismsr Dec 18, 2020
80d1621
fixed flaky test
vasilismsr Dec 18, 2020
987fa47
made inference class of cfdml private.
vasilismsr Dec 18, 2020
6097192
changed setup.cfg
vasilismsr Dec 18, 2020
79b12a2
fixed docstrings. changed docs to new classes. updated tables and lib…
vasilismsr Dec 18, 2020
e13da02
fixed centering on flowchart
vasilismsr Dec 18, 2020
74d071f
added cython to setup to
vasilismsr Dec 18, 2020
606a96a
changed notebook name and cleaned up. changed docstring in honest for…
vasilismsr Dec 18, 2020
c7f94d5
changed notebook name and cleaned up. changed docstring in honest for…
vasilismsr Dec 18, 2020
8e4c09a
fixed docs and dosctrings
vasilismsr Dec 18, 2020
2ada246
Remove Cython from package dependencies
kbattocchi Dec 18, 2020
3ad8023
fixed docstrings and test
vasilismsr Dec 18, 2020
7d7b812
docstests
vasilismsr Dec 18, 2020
b969b2f
added deprecate positional ot casualforestdml
vasilismsr Dec 18, 2020
d9c5302
linting
vasilismsr Dec 18, 2020
9741ace
added more explicit license clause from sklearn
vasilismsr Dec 19, 2020
0bd667a
added iterator and getitem in multioutputGRF and CausalForestDML
vasilismsr Dec 25, 2020
44e46c9
merged
vasilismsr Dec 25, 2020
d4f2a14
allowed sklearn 0.24.
vasilismsr Dec 25, 2020
73166eb
linting plus small shap changes
vasilismsr Dec 25, 2020
d48c0d6
doc conf point to 0.24
vasilismsr Dec 25, 2020
089702a
fixed random state bug with sklearn update. fixed bug in change in shap
vasilismsr Dec 25, 2020
ec23e9c
fixed _cross_val_predict
vasilismsr Dec 25, 2020
94aac46
enabled causalforestdml
vasilismsr Dec 25, 2020
c9c37b2
linting
vasilismsr Dec 25, 2020
7f1faa8
fixed notebooks
vasilismsr Dec 26, 2020
a13ef98
added option for max background samples to make computation more reas…
vasilismsr Dec 26, 2020
521b497
rerun notebooks
vasilismsr Dec 26, 2020
60162b2
fixed error_score param in gcvlist due to sklearn upgrade
vasilismsr Dec 26, 2020
423b57b
added shap cells in DML notebook
vasilismsr Dec 26, 2020
4d31729
fixed default value in dosctring of shap
vasilismsr Dec 26, 2020
bda9ec7
added shap values to GRF notebook
vasilismsr Dec 26, 2020
36e94a1
fixed bug in the way input_feature_names where used in summary. enabl…
vasilismsr Dec 26, 2020
60848a6
updated readme. removed autoreload from noteoboks
vasilismsr Dec 27, 2020
fe11879
added shap specific notebook
vasilismsr Dec 27, 2020
8a88046
added cell with custom grf in grf notebook
vasilismsr Dec 28, 2020
2d1bd97
updated dowhy notebook
vasilismsr Dec 28, 2020
4ea749c
removed distutils lefotver command in causla tree
vasilismsr Jan 6, 2021
060a48c
shap review comments
vasilismsr Jan 6, 2021
b56271d
addressing maggie's comments
vasilismsr Jan 6, 2021
84f76d1
partially addressing keith's review
vasilismsr Jan 6, 2021
44d6fda
fixed fall_back function
vasilismsr Jan 7, 2021
79ee0b1
fixed class names in inference
vasilismsr Jan 7, 2021
dbd35c9
addressed moprescu comments
vasilismsr Jan 8, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,40 @@
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE


Parts of this software, in particular code contained in the modules econml.tree and
econml.grf contain files that are forks from the scikit-learn git repository, or code
snippets from that repository:
https://github.com/scikit-learn/scikit-learn
published under the following License.

BSD 3-Clause License

Copyright (c) 2007-2020 The scikit-learn developers.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45 changes: 23 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,7 @@ To install from source, see [For Developers](#for-developers) section below.
treatment_effects = est.effect(X_test)
lb, ub = est.effect_interval(X_test, alpha=0.05) # Confidence intervals via debiased lasso
```

* Forest last stage

```Python
from econml.dml import ForestDML
from sklearn.ensemble import GradientBoostingRegressor

est = ForestDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingRegressor())
est.fit(Y, T, X=X, W=W)
treatment_effects = est.effect(X_test)
# Confidence intervals via Bootstrap-of-Little-Bags for forests
lb, ub = est.effect_interval(X_test, alpha=0.05)
```

* Generic Machine Learning last stage

```Python
Expand All @@ -152,16 +139,16 @@ To install from source, see [For Developers](#for-developers) section below.
<summary>Causal Forests (click to expand)</summary>

```Python
from econml.causal_forest import CausalForest
from econml.dml import CausalForestDML
from sklearn.linear_model import LassoCV
# Use defaults
est = CausalForest()
est = CausalForestDML()
# Or specify hyperparameters
est = CausalForest(n_trees=500, min_leaf_size=10,
max_depth=10, subsample_ratio=0.7,
lambda_reg=0.01,
discrete_treatment=False,
model_T=LassoCV(), model_Y=LassoCV())
est = CausalForestDML(criterion='het', n_estimators=500,
min_samples_leaf=10,
max_depth=10, max_samples=0.5,
discrete_treatment=False,
model_t=LassoCV(), model_y=LassoCV())
est.fit(Y, T, X=X, W=W)
treatment_effects = est.effect(X_test)
# Confidence intervals via Bootstrap-of-Little-Bags for forests
Expand Down Expand Up @@ -354,7 +341,7 @@ treatment_effects = est.effect(X_test)

<details>
<summary>Policy Interpreter of the CATE model (click to expand)</summary>

```Python
from econml.cate_interpreter import SingleTreePolicyInterpreter
# We find a tree-based treatment policy based on the CATE model
Expand All @@ -366,7 +353,21 @@ treatment_effects = est.effect(X_test)
plt.show()
```
![image](notebooks/images/dr_policy_tree.png)


</details>

<details>
<summary>SHAP values for the CATE model (click to expand)</summary>

```Python
import shap
from econml.dml import CausalForestDML
est = CausalForestDML()
est.fit(Y, T, X=X, W=W)
shap_values = est.shap_values(X)
shap.summary_plot(shap_values['Y0']['T0'])
```

</details>

### Inference
Expand Down
4 changes: 2 additions & 2 deletions azure-pipelines-steps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

parameters:
body: []
package: '.'
package: '-e .'

steps:
- task: UsePythonVersion@0
Expand All @@ -24,7 +24,7 @@ steps:
condition: and(succeeded(), eq(variables['Agent.OS'], 'Linux'))

# Install the package
- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel && pip install ${{ parameters.package }}'
- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel Cython && pip install ${{ parameters.package }}'
displayName: 'Install dependencies'

- ${{ parameters.body }}
5 changes: 1 addition & 4 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ jobs:
- script: 'pip install --force-reinstall --no-cache-dir shap'
displayName: 'Install public shap'

- script: 'pip install --force-reinstall scikit-learn==0.23.2'
displayName: 'Install public old sklearn'

- script: 'python setup.py build_sphinx -W'
displayName: 'Build documentation'

Expand All @@ -81,7 +78,7 @@ jobs:

- script: 'python setup.py build_sphinx -b doctest'
displayName: 'Run doctests'
package: '.[automl]'
package: '-e .[automl]'

- job: 'Notebooks'
dependsOn: 'EvalChanges'
Expand Down
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
'sklearn': ('https://scikit-learn.org/0.23/', None),
'sklearn': ('https://scikit-learn.org/stable/', None),
'matplotlib': ('https://matplotlib.org/', None)}

# -- Options for todo extension ----------------------------------------------
Expand Down
28 changes: 14 additions & 14 deletions doc/map.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 7 additions & 6 deletions doc/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@ Public Module Reference
:toctree: _autosummary

econml.bootstrap
econml.cate_estimator
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved
econml.cate_interpreter
econml.causal_forest
econml.causal_tree
econml.deepiv
econml.dgp
econml.dml
econml.drlearner
econml.grf
econml.inference
econml.metalearners
econml.ortho_forest
Expand All @@ -27,7 +24,12 @@ Private Module Reference
:toctree: _autosummary

econml._ortho_learner
econml._rlearner
econml._cate_estimator
econml._causal_tree
econml.dml._rlearner
econml.grf._base_grf
econml.grf._base_grftree
econml.grf._criterion

Scikit-Learn Extensions
=======================
Expand All @@ -37,4 +39,3 @@ Scikit-Learn Extensions

econml.sklearn_extensions.linear_model
econml.sklearn_extensions.model_selection
econml.sklearn_extensions.ensemble
17 changes: 13 additions & 4 deletions doc/spec/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ The latter translates to estimating a local gradient around a treatment vector c
\partial\tau(\vec{t}, \vec{x}) = \E\left[\nabla_{\vec{t}} Y(\vec{t}) | X=\vec{x}\right] \tag{marginal CATE}

We will refer to the latter as the *heterogeneous marginal effect*. [1]_
Finally, we might not only be interested in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quatity:

.. math ::
\mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}

We assume we have data that are generated from some collection policy. In particular, we assume that we have data of the form:
:math:`\{Y_i(T_i), T_i, X_i, W_i, Z_i\}`, where :math:`Y_i(T_i)` is the observed outcome for the chosen treatment,
Expand All @@ -43,6 +39,19 @@ The variables :math:`X_i` can also be thought of as *control* variables, but the
they are a subset of the controls with respect to which we want to measure treatment effect heterogeneity.
We will refer to them as *features*.

Finally, some times we might not only be interested in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quatity:

.. math ::
\mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}

Our package does not offer support for counterfactual prediction. However, for most of our estimators (the ones
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved
assuming a linear-in-treatment model), counterfactual prediction can be easily constructed by combining any baseline predictive model
with our causal effect model, i.e. train any machine learning model :math:`b(\vec{t}, \vec{x})` to solve the regression/classification
problem :math:`\E[Y | T=\vec{t}, X=\vec{x}]`, and then set :math:`\mu(vec{t}, \vec{x}) = \tau(\vec{t}, T, \vec{x}) + b(T, \vec{x})`,
where :math:`T` is either the observed treatment for that sample under the observational policy or the treatment
that the observational policy would have assigned to that sample. These auxiliary ML models can be trained
with any machine learning package outside of EconML.

.. rubric::
Structural Equation Formulation

Expand Down
6 changes: 3 additions & 3 deletions doc/spec/comparison.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ Detailed estimator comparison
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.LinearDRLearner` | Categorical | | Yes | | Projected | | Yes | |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDML` | 1-d/Binary | | Yes | Yes | | Yes | | Yes |
| :class:`.CausalForestDML` | Any | | Yes | Yes | | Yes | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDRLearner` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ContinuousTreatmentOrthoForest` | Continuous | | Yes | Yes | | | Yes | Yes |
| :class:`.DMLOrthoForest` | Any | | Yes | Yes | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.DiscreteTreatmentOrthoForest` | Categorical | | Yes | | | | Yes | Yes |
| :class:`.DROrthoForest` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :mod:`~econml.metalearners` | Categorical | | | | | Yes | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
Expand Down
Loading