diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 6823756a60db..26e8f62fe30f 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1221,13 +1221,29 @@ def fit( {_lgbmmodel_doc_custom_eval_note} """ - def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: + def predict( + self, + X: _DaskMatrixLike, + raw_score: bool = False, + start_iteration: int = 0, + num_iteration: Optional[int] = None, + pred_leaf: bool = False, + pred_contrib: bool = False, + validate_features: bool = False, + **kwargs: Any + ) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" return _predict( model=self.to_local(), data=X, dtype=self.classes_.dtype, client=_get_dask_client(self.client), + raw_score=raw_score, + start_iteration=start_iteration, + num_iteration=num_iteration, + pred_leaf=pred_leaf, + pred_contrib=pred_contrib, + validate_features=validate_features, **kwargs ) @@ -1394,12 +1410,28 @@ def fit( {_lgbmmodel_doc_custom_eval_note} """ - def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: + def predict( + self, + X: _DaskMatrixLike, + raw_score: bool = False, + start_iteration: int = 0, + num_iteration: Optional[int] = None, + pred_leaf: bool = False, + pred_contrib: bool = False, + validate_features: bool = False, + **kwargs: Any + ) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" return _predict( model=self.to_local(), data=X, client=_get_dask_client(self.client), + raw_score=raw_score, + start_iteration=start_iteration, + num_iteration=num_iteration, + pred_leaf=pred_leaf, + pred_contrib=pred_contrib, + validate_features=validate_features, **kwargs ) @@ -1552,12 +1584,28 @@ def fit( {_lgbmmodel_doc_custom_eval_note} """ - def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: + def predict( + self, + X: _DaskMatrixLike, + raw_score: bool = False, + start_iteration: int = 0, + num_iteration: Optional[int] = None, + pred_leaf: bool = False, + pred_contrib: bool = False, + validate_features: bool = False, + **kwargs: Any + ) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRanker.predict.""" return _predict( model=self.to_local(), data=X, client=_get_dask_client(self.client), + raw_score=raw_score, + start_iteration=start_iteration, + num_iteration=num_iteration, + pred_leaf=pred_leaf, + pred_contrib=pred_contrib, + validate_features=validate_features, **kwargs )