Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform argument to plots #1036

Merged
merged 17 commits into from
Feb 6, 2020
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
* Integrated rcParams `plot.point_estimate` (#994), `stats.ic_scale` (#993) and `stats.credible_interval` (#1017)
* Added `group` argument to `plot_ppc` (#1008), `plot_pair` (#1009) and `plot_joint` (#1012)
* Add `skipna` argument to `hpd` and `summary` (#1035)
* Added `transform` argument to `plot_trace`, `plot_forest`, `plot_pair`, `plot_posterior`, `plot_rank`, `plot_parallel`, `plot_violin`,`plot_density`, `plot_joint` (#1036)


### Maintenance and fixes
* Fixed bug in extracting prior samples for cmdstanpy (#979)
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def plot_density(
group="posterior",
data_labels=None,
var_names=None,
transform=None,
credible_interval=None,
point_estimate="auto",
colors="cycle",
Expand Down Expand Up @@ -58,6 +59,8 @@ def plot_density(
List of variables to plot. If multiple datasets are supplied and var_names is not None,
will print the same set of variables for each dataset. Defaults to None, which results in
all the variables being plotted.
transform : callable
Function to transform data (defaults to None i.e. the identity function)
credible_interval : float
Credible intervals. Should be in the interval (0, 1]. Defaults to 0.94.
point_estimate : Optional[str]
Expand Down Expand Up @@ -147,6 +150,8 @@ def plot_density(

>>> az.plot_density([centered, non_centered], var_names=["mu"], bw=.9)
"""
if transform is not None:
data = transform(data)
if not isinstance(data, (list, tuple)):
datasets = [convert_to_dataset(data, group=group)]
else:
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def plot_forest(
kind="forestplot",
model_names=None,
var_names=None,
transform=None,
coords=None,
combined=False,
credible_interval=None,
Expand Down Expand Up @@ -48,6 +49,8 @@ def plot_forest(
var_names: list[str], optional
List of variables to plot (defaults to None, which results in all
variables plotted)
transform : callable
Function to transform data (defaults to None i.e.the identity function)
coords : dict, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
combined : bool
Expand Down Expand Up @@ -135,6 +138,8 @@ def plot_forest(
"""
if not isinstance(data, (list, tuple)):
data = [data]
if transform is not None:
data = transform(data)

if coords is None:
coords = {}
Expand Down
6 changes: 6 additions & 0 deletions arviz/plots/jointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def plot_joint(
data,
group="posterior",
var_names=None,
transform=None,
coords=None,
figsize=None,
textsize=None,
Expand Down Expand Up @@ -35,6 +36,8 @@ def plot_joint(
var_names : str or iterable of str
Variables to be plotted. iter of two variables or one variable (with subset having
exactly 2 dimensions) are required.
transform : callable
Function to transform data (defaults to None i.e. the identity function)
coords : mapping, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
figsize : tuple
Expand Down Expand Up @@ -136,6 +139,9 @@ def plot_joint(

data = convert_to_dataset(data, group=group)

if transform is not None:
data = transform(data)

if coords is None:
coords = {}

Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
def plot_posterior(
data,
var_names=None,
transform=None,
coords=None,
figsize=None,
textsize=None,
Expand Down Expand Up @@ -45,6 +46,8 @@ def plot_posterior(
Refer to documentation of az.convert_to_dataset for details
var_names : list of variable names
Variables to be plotted, two variables are required.
transform : callable
Function to transform data (defaults to None i.e.the identity function)
coords : mapping, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
figsize : tuple
Expand Down Expand Up @@ -177,6 +180,8 @@ def plot_posterior(
>>> az.plot_posterior(data, var_names=['mu'], credible_interval=.75)
"""
data = convert_to_dataset(data, group=group)
if transform is not None:
data = transform(data)
var_names = _var_names(var_names, data)

if coords is None:
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/rankplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
def plot_rank(
data,
var_names=None,
transform=None,
coords=None,
bins=None,
kind="bars",
Expand Down Expand Up @@ -50,6 +51,8 @@ def plot_rank(
az.convert_to_dataset for details
var_names : string or list of variable names
Variables to be plotted
transform : callable
Function to transform data (defaults to None i.e.the identity function)
coords : mapping, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
bins : None or passed to np.histogram
Expand Down Expand Up @@ -116,6 +119,8 @@ def plot_rank(
>>> az.plot_rank(noncentered_data, var_names="mu", kind='vlines', axes=ax[1])

"""
if transform is not None:
data = transform(data)
posterior_data = convert_to_dataset(data, group="posterior")
if coords is not None:
posterior_data = posterior_data.sel(**coords)
Expand Down
7 changes: 7 additions & 0 deletions arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
def plot_trace(
data,
var_names=None,
transform=None,
coords=None,
divergences="bottom",
figsize=None,
Expand Down Expand Up @@ -46,6 +47,8 @@ def plot_trace(
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
divergences : {"bottom", "top", None, False}
Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
transform : callable
Function to transform data (defaults to None i.e.the identity function)
figsize : figure size tuple
If None, size is (12, variables * 2)
rug : bool
Expand Down Expand Up @@ -137,6 +140,10 @@ def plot_trace(
divergence_data = False

data = get_coords(convert_to_dataset(data, group="posterior"), coords)

if transform is not None:
data = transform(data)

var_names = _var_names(var_names, data)

if lines is None:
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def plot_violin(
data,
var_names=None,
transform=None,
quartiles=True,
rug=False,
credible_interval=None,
Expand Down Expand Up @@ -43,6 +44,8 @@ def plot_violin(
Refer to documentation of az.convert_to_dataset for details
var_names: list, optional
List of variables to plot (defaults to None, which results in all variables plotted)
transform : callable
Function to transform data (defaults to None i.e. the identity function)
quartiles : bool, optional
Flag for plotting the interquartile range, in addition to the credible_interval*100%
intervals. Defaults to True
Expand Down Expand Up @@ -86,6 +89,8 @@ def plot_violin(
axes : matplotlib axes or bokeh figures
"""
data = convert_to_dataset(data, group="posterior")
if transform is not None:
data = transform(data)
var_names = _var_names(var_names, data)

plotters = filter_plotters_list(
Expand Down