Skip to content

Commit

Permalink
Fix legend in separation plot (#1701)
Browse files Browse the repository at this point in the history
* small fix for legend

* fix bokeh plot

* update changelog
  • Loading branch information
agustinaarroyuelo authored May 25, 2021
1 parent 7ebedd2 commit aab6b1d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
* Fix `xlabels` in `plot_elpd` ([1601](https://github.com/arviz-devs/arviz/pull/1601))
* Renamed `sample` dim to `__sample__` when stacking `chain` and `draw` to avoid dimension collision ([1647](https://github.com/arviz-devs/arviz/pull/1647))
* Removed the `circular` argument in `plot_dist` in favor of `is_circular` ([1681](https://github.com/arviz-devs/arviz/pull/1681))
* Fix `legend` argument in `plot_separation` ([1701](https://github.com/arviz-devs/arviz/pull/1701))

### Deprecation
* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
11 changes: 8 additions & 3 deletions arviz/plots/backends/bokeh/separationplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def plot_separation(
figsize,
textsize,
color,
legend, # pylint: disable=unused-argument
legend,
locs,
width,
ax,
Expand Down Expand Up @@ -53,6 +53,13 @@ def plot_separation(
exp_events_kwargs.setdefault("color", "black")
exp_events_kwargs.setdefault("size", 15)

if legend:
y_hat_line_kwargs.setdefault("legend_label", label_y_hat)
exp_events_kwargs.setdefault(
"legend_label",
"Expected events",
)

figsize, *_ = _scale_fig_size(figsize, textsize)

idx = np.argsort(y_hat)
Expand All @@ -79,7 +86,6 @@ def plot_separation(
ax.line(
np.linspace(0, 1, len(y_hat)),
y_hat[idx],
legend_label=label_y_hat,
**y_hat_line_kwargs,
)

Expand All @@ -88,7 +94,6 @@ def plot_separation(
ax.triangle(
y_hat[idx][len(y_hat) - expected_events - 1],
0,
legend_label="Expected events",
**exp_events_kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/separationplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def plot_separation(
**exp_events_kwargs
)

if legend and expected_events or y_hat_line:
if legend and (expected_events or y_hat_line):
handles, labels = ax.get_legend_handles_labels()
labels_dict = dict(zip(labels, handles))
ax.legend(labels_dict.values(), labels_dict.keys())
Expand Down

0 comments on commit aab6b1d

Please sign in to comment.