Skip to content

Commit

Permalink
[python-package] [dask] fix mypy errors regarding predict_proba (#5728)
Browse files Browse the repository at this point in the history
  • Loading branch information
IdoKendo authored Feb 19, 2023
1 parent f136de4 commit f975d3f
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,13 +1255,29 @@ def predict(
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]"
)

def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
def predict_proba(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict(
model=self.to_local(),
data=X,
pred_proba=True,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)

Expand Down

0 comments on commit f975d3f

Please sign in to comment.