diff --git a/CHANGELOG.md b/CHANGELOG.md index a1a7582d66..78bdc95a30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ * Fix `xlabels` in `plot_elpd` ([1601](https://github.com/arviz-devs/arviz/pull/1601)) * Renamed `sample` dim to `__sample__` when stacking `chain` and `draw` to avoid dimension collision ([1647](https://github.com/arviz-devs/arviz/pull/1647)) * Removed the `circular` argument in `plot_dist` in favor of `is_circular` ([1681](https://github.com/arviz-devs/arviz/pull/1681)) +* Fix `legend` argument in `plot_separation` ([1701](https://github.com/arviz-devs/arviz/pull/1701)) ### Deprecation * Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201)) diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index 78c153b3ec..28ef3c68cd 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -15,7 +15,7 @@ def plot_separation( figsize, textsize, color, - legend, # pylint: disable=unused-argument + legend, locs, width, ax, @@ -53,6 +53,13 @@ def plot_separation( exp_events_kwargs.setdefault("color", "black") exp_events_kwargs.setdefault("size", 15) + if legend: + y_hat_line_kwargs.setdefault("legend_label", label_y_hat) + exp_events_kwargs.setdefault( + "legend_label", + "Expected events", + ) + figsize, *_ = _scale_fig_size(figsize, textsize) idx = np.argsort(y_hat) @@ -79,7 +86,6 @@ def plot_separation( ax.line( np.linspace(0, 1, len(y_hat)), y_hat[idx], - legend_label=label_y_hat, **y_hat_line_kwargs, ) @@ -88,7 +94,6 @@ def plot_separation( ax.triangle( y_hat[idx][len(y_hat) - expected_events - 1], 0, - legend_label="Expected events", **exp_events_kwargs, ) diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index 3b8bc5bb57..528f89ddb0 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -80,7 +80,7 @@ def plot_separation( **exp_events_kwargs ) - if legend and expected_events or y_hat_line: + if legend and (expected_events or y_hat_line): handles, labels = ax.get_legend_handles_labels() labels_dict = dict(zip(labels, handles)) ax.legend(labels_dict.values(), labels_dict.keys())