diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index ad805eef733..fe97053a424 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -662,6 +662,10 @@ def __init__( self._n_classes: int = -1 self.set_params(**kwargs) + # scikit-learn 1.6 introduced an __sklearn__tags() method intended to replace _more_tags(). + # _more_tags() can be removed whenever lightgbm's minimum supported scikit-learn version + # is >=1.6. + # ref: https://github.com/microsoft/LightGBM/pull/6651 def _more_tags(self) -> Dict[str, Any]: return { "allow_nan": True, @@ -673,6 +677,15 @@ def _more_tags(self) -> Dict[str, Any]: }, } + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + more_tags = self._more_tags() + tags.input_tags.allow_nan = more_tags["allow_nan"] + tags.input_tags.sparse = "sparse" in more_tags["X_types"] + tags.target_tags.one_d_labels = "1dlabels" in more_tags["X_types"] + tags._xfail_checks = more_tags["_xfail_checks"] + return tags + def __sklearn_is_fitted__(self) -> bool: return getattr(self, "fitted_", False)