Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Khat deprecate annotate in favor of threshold #1478

Merged
merged 4 commits into from
Dec 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
* Switch to `compact=True` by default in our plots ([1468](https://github.com/arviz-devs/arviz/issues/1468))
* `plot_elpd`, avoid modifying the input dict ([1477](https://github.com/arviz-devs/arviz/issues/1477))


### Deprecation
* `plot_khat` deprecate `annotate` argument in favor of `threshold`. The new argument accepts floats ([1478](https://github.com/arviz-devs/arviz/issues/1478))

### Documentation
* Reorganize documentation and change sphinx theme ([1406](https://github.com/arviz-devs/arviz/pull/1406))
Expand Down
36 changes: 21 additions & 15 deletions arviz/plots/backends/bokeh/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ def plot_khat(
xdata,
khats,
kwargs,
annotate,
threshold,
coord_labels,
show_hlines,
show_bins,
hlines_kwargs, # pylint: disable=unused-argument
hlines_kwargs,
xlabels, # pylint: disable=unused-argument
legend, # pylint: disable=unused-argument
color,
Expand All @@ -49,6 +50,10 @@ def plot_khat(

(figsize, *_, line_width, _) = _scale_fig_size(figsize, textsize)

if hlines_kwargs is None:
hlines_kwargs = {}
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])

cmap = None
if isinstance(color, str):
if color in dims:
Expand Down Expand Up @@ -103,21 +108,21 @@ def plot_khat(
fill_alpha=alphas,
)

if annotate:
idxs = xdata[khats > 1]
if threshold is not None:
idxs = xdata[khats > threshold]
for idx in idxs:
ax.text(x=[idx], y=[khats[idx]], text=[coord_labels[idx]])

for hline in [0, 0.5, 0.7, 1]:
_hline = Span(
location=hline,
dimension="width",
line_color="grey",
line_width=line_width,
line_dash="dashed",
)

ax.renderers.append(_hline)
if show_hlines:
for hline in hlines_kwargs.pop("hlines"):
_hline = Span(
location=hline,
dimension="width",
line_color="grey",
line_width=line_width,
line_dash="dashed",
)
ax.renderers.append(_hline)

ymin = min(khats)
ymax = max(khats)
Expand All @@ -134,14 +139,15 @@ def plot_khat(
text=[bin_format.format(count, count / n_data_points * 100)],
)
ax.x_range._property_values["end"] = xmax + 1 # pylint: disable=protected-access

ax.xaxis.axis_label = "Data Point"
ax.yaxis.axis_label = "Shape parameter k"

if ymin > 0:
ax.y_range._property_values["start"] = -0.02 # pylint: disable=protected-access
if ymax < 1:
ax.y_range._property_values["end"] = 1.02 # pylint: disable=protected-access
elif ymax > 1 & annotate:
elif ymax > 1 & threshold:
ax.y_range._property_values["end"] = 1.1 * ymax # pylint: disable=protected-access

show_layout(ax, show)
Expand Down
17 changes: 11 additions & 6 deletions arviz/plots/backends/matplotlib/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def plot_khat(
xdata,
khats,
kwargs,
annotate,
threshold,
coord_labels,
show_hlines,
show_bins,
hlines_kwargs,
xlabels,
Expand Down Expand Up @@ -61,6 +62,7 @@ def plot_khat(
backend_kwargs["squeeze"] = True

hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])
hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
hlines_kwargs.setdefault("alpha", 0.7)
hlines_kwargs.setdefault("zorder", -1)
Expand Down Expand Up @@ -109,8 +111,8 @@ def plot_khat(

sc_plot = ax.scatter(xdata, khats, c=rgba_c, **kwargs)

if annotate:
idxs = xdata[khats > 1]
if threshold is not None:
idxs = xdata[khats > threshold]
for idx in idxs:
ax.text(
idx,
Expand All @@ -125,10 +127,15 @@ def plot_khat(
if show_bins:
xmax += n_data_points / 12
ylims1 = ax.get_ylim()
ax.hlines([0, 0.5, 0.7, 1], xmin=xmin, xmax=xmax, linewidth=linewidth, **hlines_kwargs)
ylims2 = ax.get_ylim()
ymin = min(ylims1[0], ylims2[0])
ymax = min(ylims1[1], ylims2[1])

if show_hlines:
ax.hlines(
hlines_kwargs.pop("hlines"), xmin=xmin, xmax=xmax, linewidth=linewidth, **hlines_kwargs
)

if show_bins:
bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
Expand All @@ -141,8 +148,6 @@ def plot_khat(
horizontalalignment="center",
verticalalignment="center",
)
ax.set_ylim(ymin, ymax)
ax.set_xlim(xmin, xmax)

ax.set_xlabel("Data Point", fontsize=ax_labelsize)
ax.set_ylabel(r"Shape parameter k", fontsize=ax_labelsize)
Expand Down
24 changes: 19 additions & 5 deletions arviz/plots/khatplot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Pareto tail indices plot."""
import logging

import numpy as np
from xarray import DataArray

Expand All @@ -7,14 +9,18 @@
from ..utils import get_coords
from .plot_utils import format_coords_as_labels, get_plotting_function

_log = logging.getLogger(__name__)


def plot_khat(
khats,
color="C0",
xlabels=False,
show_hlines=False,
show_bins=False,
bin_format="{1:.1f}%",
annotate=False,
threshold=None,
hover_label=False,
hover_format="{1}",
figsize=None,
Expand Down Expand Up @@ -42,12 +48,15 @@ def plot_khat(
otherwise, it will be interpreted as a list of the dims to be used for the color code
xlabels : bool, optional
Use coords as xticklabels
show_hlines : bool, optional
Show the horizontal lines, by default at the values [0, 0.5, 0.7, 1].
show_bins : bool, optional
Show the number of khats which fall in each bin.
Show the percentage of khats falling in each bin, as delimited by hlines.
bin_format : str, optional
The string is used as formatting guide calling ``bin_format.format(count, pct)``.
annotate : bool, optional
Show the labels of k values larger than 1.
threshold : float, optional
Show the labels of k values larger than threshold. Defaults to `None`,
no observations will be highlighted.
hover_label : bool, optional
Show the datapoint label when hovering over it with the mouse. Requires an interactive
backend.
Expand Down Expand Up @@ -103,7 +112,7 @@ def plot_khat(

>>> centered_eight = az.load_arviz_data("centered_eight")
>>> khats = az.loo(centered_eight, pointwise=True).pareto_k
>>> az.plot_khat(khats, xlabels=True, annotate=True)
>>> az.plot_khat(khats, xlabels=True, threshold=1)

Use custom color scheme

Expand All @@ -117,6 +126,10 @@ def plot_khat(
>>> az.plot_khat(loo_radon, color=colors)

"""
if annotate:
_log.warning("annotate will be deprecated, please use threshold instead")
threshold = annotate

if coords is None:
coords = {}

Expand Down Expand Up @@ -152,8 +165,9 @@ def plot_khat(
xdata=xdata,
khats=khats,
kwargs=kwargs,
annotate=annotate,
threshold=threshold,
coord_labels=coord_labels,
show_hlines=show_hlines,
show_bins=show_bins,
hlines_kwargs=hlines_kwargs,
xlabels=xlabels,
Expand Down
18 changes: 14 additions & 4 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,12 @@ def test_plot_joint_bad(models):
{"color": "obs_dim", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"obs_dim": slice(2, 4)}},
{"color": np.random.uniform(size=8), "show_bins": True},
{"color": np.random.uniform(size=(8, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(8, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -628,7 +633,12 @@ def test_plot_khat(models, input_type, kwargs):
{"color": "dim2", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"dim2": slice(2, 4)}},
{"color": np.random.uniform(size=35), "show_bins": True},
{"color": np.random.uniform(size=(35, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(35, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -650,9 +660,9 @@ def test_plot_khat_multidim(multidim_models, input_type, kwargs):
assert axes


def test_plot_khat_annotate():
def test_plot_khat_threshold():
khats = np.array([0, 0, 0.6, 0.6, 0.8, 0.9, 0.9, 2, 3, 4, 1.5])
axes = plot_khat(khats, annotate=True, backend="bokeh", show=False)
axes = plot_khat(khats, threshold=1, backend="bokeh", show=False)
assert axes


Expand Down
18 changes: 14 additions & 4 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,12 @@ def test_plot_elpd_one_model(models):
{"color": "obs_dim", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"obs_dim": slice(2, 4)}},
{"color": np.random.uniform(size=8), "show_bins": True},
{"color": np.random.uniform(size=(8, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(8, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -1165,7 +1170,12 @@ def test_plot_khat(models, input_type, kwargs):
{"color": "dim2", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"dim2": slice(2, 4)}},
{"color": np.random.uniform(size=35), "show_bins": True},
{"color": np.random.uniform(size=(35, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(35, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -1187,9 +1197,9 @@ def test_plot_khat_multidim(multidim_models, input_type, kwargs):
assert axes


def test_plot_khat_annotate():
def test_plot_khat_threshold():
khats = np.array([0, 0, 0.6, 0.6, 0.8, 0.9, 0.9, 2, 3, 4, 1.5])
axes = plot_khat(khats, annotate=True)
axes = plot_khat(khats, threshold=1)
assert axes


Expand Down