Hyperparameter GridSearch with CausalForestDML #397

Closed
jaronowitz opened this issue Feb 2, 2021 · 7 comments

Comments

@jaronowitz
Contributor

Is there a way to use scikit-learn's GridSearchCV with CausalForestDML? The EconML user guide, under "How do I select the hyperparameters of the final model (if any)?", has an example of this with NonParamDML, but CausalForestDML does not have a model_final parameter. I'd like to tune the n_estimators, max_depth, and other parameters of CausalForestDML.

When I try

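from sklearn.model_selection import GridSearchCV
from econml.dml import CausalForestDML
from econml.sklearn_extensions.linear_model import WeightedLassoCV, WeightedMultiTaskLassoCV

# est_dict, t, T_df, cv_parameters, y, X and W are defined elsewhere.
est = CausalForestDML(inference='bootstrap', cv=5,
                      model_t=WeightedMultiTaskLassoCV(),
                      model_y=WeightedLassoCV())

# Wrap the estimator in a grid search over cv_parameters.
est_dict[t] = GridSearchCV(estimator=est, param_grid=cv_parameters)

col_list = make_col_list_polynomial_3(t)
T = T_df[col_list]

# This call raises the TypeError shown below.
est_dict[t].fit(Y=y, T=T, X=X, W=W)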

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<timed exec> in <module>()

/opt/conda/anaconda/lib/python3.6/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     61             extra_args = len(args) - len(all_args)
     62             if extra_args <= 0:
---> 63                 return f(*args, **kwargs)
     64 
     65             # extra_args > 0

/opt/conda/anaconda/lib/python3.6/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
    763         n_splits = cv_orig.get_n_splits(X, y, groups)
    764 
--> 765         base_estimator = clone(self.estimator)
    766 
    767         parallel = Parallel(n_jobs=self.n_jobs,

/opt/conda/anaconda/lib/python3.6/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     61             extra_args = len(args) - len(all_args)
     62             if extra_args <= 0:
---> 63                 return f(*args, **kwargs)
     64 
     65             # extra_args > 0

/opt/conda/anaconda/lib/python3.6/site-packages/sklearn/base.py in clone(estimator, safe)
     69                                 "estimator as it does not implement a "
     70                                 "'get_params' method."
---> 71                                 % (repr(estimator), type(estimator)))
     72 
     73     klass = estimator.__class__

TypeError: Cannot clone object '<econml.dml.causal_forest.CausalForestDML object at 0x7f6c1284cd68>' (type <class 'econml.dml.causal_forest.CausalForestDML'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' method.
@vsyrgkanis
Collaborator

You can't use GridSearchCV itself, but in the 0.9 release the CausalForestDML class will have a “tune” method that tunes parameters internally, and you can define whatever grid of parameters you want, similar to GridSearchCV.
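
For context, the TypeError in the traceback comes from sklearn.base.clone, which GridSearchCV calls on the estimator before fitting and which refuses any object without a get_params method. A minimal sketch of that check follows. The two classes are hypothetical stand-ins written for illustration, not EconML code.

```python
from sklearn.base import BaseEstimator, clone


class NoParamsEstimator:
    """Hypothetical stand-in for an estimator that lacks get_params."""

    def __init__(self, max_depth=3):
        self.max_depth = max_depth

    def fit(self, X, y):
        return self


class WithParamsEstimator(BaseEstimator):
    """Same stand-in, but get_params/set_params come from BaseEstimator."""

    def __init__(self, max_depth=3):
        self.max_depth = max_depth

    def fit(self, X, y):
        return self


# clone() rejects objects that do not expose get_params (safe=True is the default).
try:
    clone(NoParamsEstimator())
    clone_error = None
except TypeError as exc:
    clone_error = str(exc)

# With get_params available, clone() builds a fresh, unfitted copy with the same parameters.
cloned = clone(WithParamsEstimator(max_depth=5))

print(clone_error)
print(cloned.max_depth)
```

Any wrapper that relies on clone, GridSearchCV included, hits the same check, which is why the grid search fails before a single model is fit.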

@jaronowitz
Contributor Author

Thanks for the reply! Sounds like there are a lot of great new features in 0.9. Bring it on :)

@vsyrgkanis
Collaborator

If you want to use it now, you can install the vasilis/grf_and_scorer_docs branch of the git repo, which has not yet been merged to master, i.e.:

pip install git+https://github.com/microsoft/EconML.git@vasilis/grf_and_scorer_docs

@jaronowitz
Contributor Author

Hi @vsyrgkanis, I was checking out the methods for CausalForestDML (https://econml.azurewebsites.net/_autosummary/econml.dml.CausalForestDML.html) now that 0.9 is released. Is the tune method included in this release? Thanks for the help.

@vsyrgkanis
Collaborator

Unfortunately, it didn't make it into the base 0.9 release. It will be in an upcoming point release, maybe next week. See the PR here:
#390

@vsyrgkanis
Collaborator

This is now addressed by #390 and will be included in the 0.9.2 release.

@vsyrgkanis
Collaborator

0.9.2 has been released on PyPI and hyperparameter tuning is available.
