Skip to content

Commit

Permalink
[python-package] Fix mypy errors for predict() method (#5678)
Browse files Browse the repository at this point in the history
  • Loading branch information
IdoKendo authored Jan 19, 2023
1 parent 3c3f79e commit 7af85ce
Showing 1 changed file with 51 additions and 3 deletions.
54 changes: 51 additions & 3 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,13 +1221,29 @@ def fit(
{_lgbmmodel_doc_custom_eval_note}
"""

def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
def predict(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
return _predict(
model=self.to_local(),
data=X,
dtype=self.classes_.dtype,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)

Expand Down Expand Up @@ -1394,12 +1410,28 @@ def fit(
{_lgbmmodel_doc_custom_eval_note}
"""

def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
def predict(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
return _predict(
model=self.to_local(),
data=X,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)

Expand Down Expand Up @@ -1552,12 +1584,28 @@ def fit(
{_lgbmmodel_doc_custom_eval_note}
"""

def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
def predict(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(
model=self.to_local(),
data=X,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)

Expand Down

0 comments on commit 7af85ce

Please sign in to comment.