From 7af85cee692132cf6695bf5832b905f261291d68 Mon Sep 17 00:00:00 2001 From: IdoKendo <41922392+IdoKendo@users.noreply.github.com> Date: Thu, 19 Jan 2023 03:23:07 +0200 Subject: [PATCH] [python-package] Fix mypy errors for predict() method (#5678) --- python-package/lightgbm/dask.py | 54 +++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 6823756a60db..26e8f62fe30f 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1221,13 +1221,29 @@ def fit( {_lgbmmodel_doc_custom_eval_note} """ - def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: + def predict( + self, + X: _DaskMatrixLike, + raw_score: bool = False, + start_iteration: int = 0, + num_iteration: Optional[int] = None, + pred_leaf: bool = False, + pred_contrib: bool = False, + validate_features: bool = False, + **kwargs: Any + ) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" return _predict( model=self.to_local(), data=X, dtype=self.classes_.dtype, client=_get_dask_client(self.client), + raw_score=raw_score, + start_iteration=start_iteration, + num_iteration=num_iteration, + pred_leaf=pred_leaf, + pred_contrib=pred_contrib, + validate_features=validate_features, **kwargs ) @@ -1394,12 +1410,28 @@ def fit( {_lgbmmodel_doc_custom_eval_note} """ - def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: + def predict( + self, + X: _DaskMatrixLike, + raw_score: bool = False, + start_iteration: int = 0, + num_iteration: Optional[int] = None, + pred_leaf: bool = False, + pred_contrib: bool = False, + validate_features: bool = False, + **kwargs: Any + ) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" return _predict( model=self.to_local(), data=X, client=_get_dask_client(self.client), + raw_score=raw_score, + start_iteration=start_iteration, + num_iteration=num_iteration, + pred_leaf=pred_leaf, + pred_contrib=pred_contrib, + validate_features=validate_features, **kwargs ) @@ -1552,12 +1584,28 @@ def fit( {_lgbmmodel_doc_custom_eval_note} """ - def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: + def predict( + self, + X: _DaskMatrixLike, + raw_score: bool = False, + start_iteration: int = 0, + num_iteration: Optional[int] = None, + pred_leaf: bool = False, + pred_contrib: bool = False, + validate_features: bool = False, + **kwargs: Any + ) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRanker.predict.""" return _predict( model=self.to_local(), data=X, client=_get_dask_client(self.client), + raw_score=raw_score, + start_iteration=start_iteration, + num_iteration=num_iteration, + pred_leaf=pred_leaf, + pred_contrib=pred_contrib, + validate_features=validate_features, **kwargs )