Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
vuillaut committed May 15, 2021
1 parent 78519ad commit 6ebe975
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ctaplot/plots/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..ana import ana
from ..io.dataset import load_any_resource
from sklearn.metrics import precision_recall_curve, PrecisionRecallDisplay
from sklearn.metrics import recall_score, precision_score

__all__ = ['plot_resolution',
'plot_resolution_difference',
Expand Down Expand Up @@ -2132,6 +2133,8 @@ def plot_precision_recall(y_true, proba_pred, pos_label=0, sample_weigth=None, t
prec, recall, thresholds = precision_recall_curve(y_true, proba_pred, pos_label=pos_label,
sample_weight=sample_weigth)

pr_display = PrecisionRecallDisplay(precision=prec, recall=recall, pos_label=pos_label).plot(ax=ax, **kwargs)

if threshold is not None:
pred = (proba_pred > threshold).astype(int)
neg_label = list(set(y_true))
Expand All @@ -2144,4 +2147,4 @@ def plot_precision_recall(y_true, proba_pred, pos_label=0, sample_weigth=None, t
p = precision_score(y_true, pred_labels, pos_label=pos_label)
pr_display.ax_.scatter(r, p)

return PrecisionRecallDisplay(precision=prec, recall=recall, pos_label=pos_label).plot(ax=ax, **kwargs)
return pr_display

0 comments on commit 6ebe975

Please sign in to comment.