diff --git a/doc/docstrings/histplot.ipynb b/doc/docstrings/histplot.ipynb index ed4ef6e85b..99ed6c551d 100644 --- a/doc/docstrings/histplot.ipynb +++ b/doc/docstrings/histplot.ipynb @@ -461,9 +461,9 @@ ], "metadata": { "kernelspec": { - "display_name": "seaborn-refactor (py38)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-refactor" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { diff --git a/doc/docstrings/stripplot.ipynb b/doc/docstrings/stripplot.ipynb new file mode 100644 index 0000000000..a88d94b807 --- /dev/null +++ b/doc/docstrings/stripplot.ipynb @@ -0,0 +1,313 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme(style=\"whitegrid\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a single numeric variable shows its univariate distribution with points randomly \"jittered\" on the other axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "sns.stripplot(data=tips, x=\"total_bill\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a second variable splits the strips of poins to compare categorical levels of that variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Show vertically-oriented strips by swapping the assignment of the categorical and numerical variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"day\", y=\"total_bill\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Prior to version 0.12, the levels of the categorical variable had different colors. 
To get the same effect, assign the `hue` variable explicitly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"day\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or you can assign a distinct variable to `hue` to show a multidimensional relationship:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If the `hue` variable is numeric, it will be mapped with a quantitative palette by default (this was not the case prior to version 0.12):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use `palette` to control the color mapping, including forcing a categorical mapping by passing the name of a qualitative palette:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\", palette=\"deep\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the different levels of the `hue` variable are intermingled in each strip, but setting `dodge=True` will split them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The random jitter can be disabled by setting `jitter=False`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True, jitter=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If plotting in wide-form mode, each column of the dataframe will be mapped to both `x` and `hue`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To change the orientation while in wide-form mode, pass `orient` explicitly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, orient=\"h\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The `orient` parameter is also useful when both axis variables are numeric, as it will resolve ambiguity about which dimension to group (and jitter) along:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"size\", orient=\"h\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the categorical variable will be mapped to discrete indices with a fixed scale (0, 1, ...), even when it is numeric:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(\n", + " data=tips.query(\"size in [2, 3, 5]\"),\n", + " x=\"total_bill\", 
y=\"size\", orient=\"h\",\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To disable this behavior and use the original scale of the variable, set `fixed_scale=False`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(\n", + " data=tips.query(\"size in [2, 3, 5]\"),\n", + " x=\"total_bill\", y=\"size\", orient=\"h\",\n", + " fixed_scale=False,\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Further visual customization can be achieved by passing matplotlib keyword arguments:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(\n", + " data=tips, x=\"total_bill\", y=\"day\", hue=\"time\",\n", + " jitter=False, s=20, marker=\"D\", linewidth=1, alpha=.1,\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To make a plot with multiple facets, it is safer to use :func:`catplot` than to work with :class:`FacetGrid` directly, because :func:`catplot` will ensure that the categorical and hue variables are properly synchronized in each facet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.catplot(data=tips, x=\"time\", y=\"total_bill\", hue=\"sex\", col=\"day\", aspect=.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/releases/v0.12.0.txt b/doc/releases/v0.12.0.txt index 4163d56c4a..4a12cddbbb 100644 --- a/doc/releases/v0.12.0.txt +++ b/doc/releases/v0.12.0.txt @@ -6,6 +6,10 @@ v0.12.0 (Unreleased) - |Fix| |Enhancement| Improved robustness to missing data, including additional support for the `pd.NA` type (:pr:`2417). +- TODO function specific categorical enhancements, including: + + - In :func:`stripplot`, a "strip" with a single observation will be plotted without jitter (:pr:`2413`) + - Made `scipy` an optional dependency and added `pip install seaborn[all]` as a method for ensuring the availability of compatible `scipy` and `statsmodels` libraries at install time. This has a few minor implications for existing code, which are explained in the Github pull request (:pr:`2398`). - Following `NEP29 `_, dropped support for Python 3.6 and bumped the minimally-supported versions of the library dependencies. diff --git a/seaborn/_core.py b/seaborn/_core.py index eff314a552..083a2ea9b5 100644 --- a/seaborn/_core.py +++ b/seaborn/_core.py @@ -20,6 +20,7 @@ color_palette, ) from .utils import ( + _check_argument, get_color_cycle, remove_na, ) @@ -43,7 +44,7 @@ def __init__(self, plotter): # TODO Putting this here so we can continue to use a lot of the # logic that's built into the library, but the idea of this class - # is to move towards semantic mappings that are agnositic about the + # is to move towards semantic mappings that are agnostic about the # kind of plot they're going to be used to draw. # Fully achieving that is going to take some thinking. 
self.plotter = plotter @@ -602,6 +603,12 @@ class VectorPlotter: def __init__(self, data=None, variables={}): + self._var_levels = {} + # var_ordered is relevant only for categorical axis variables, and may + # be better handled by an internal axis information object that tracks + # such information and is set up by the scale_* methods. The analogous + # information for numeric axes would be information about log scales. + self._var_ordered = {"x": False, "y": False} # alt., used DefaultDict self.assign_variables(data, variables) for var, cls in self._semantic_mappings.items(): @@ -613,8 +620,6 @@ def __init__(self, data=None, variables={}): # Call the mapping function to initialize with default values getattr(self, f"map_{var}")() - self._var_levels = {} - @classmethod def get_semantics(cls, kwargs, semantics=None): """Subset a dictionary` arguments with known semantic variables.""" @@ -679,6 +684,12 @@ def assign_variables(self, data=None, variables={}): for v in variables } + # XXX does this make sense here? + for axis in "xy": + if axis not in variables: + continue + self.var_levels[axis] = categorical_order(self.plot_data[axis]) + return self def _assign_variables_wideform(self, data=None, **kwargs): @@ -938,7 +949,9 @@ def _assign_variables_longform(self, data=None, **kwargs): return plot_data, variables def iter_data( - self, grouping_vars=None, reverse=False, from_comp_data=False, + self, grouping_vars=None, *, + reverse=False, from_comp_data=False, + by_facet=True, allow_empty=False, ): """Generator for getting subsets of data defined by semantic variables. @@ -948,10 +961,15 @@ def iter_data( ---------- grouping_vars : string or list of strings Semantic variables that define the subsets of data. - reverse : bool, optional + reverse : bool If True, reverse the order of iteration. - from_comp_data : bool, optional + from_comp_data : bool If True, use self.comp_data rather than self.plot_data + by_facet : bool + If True, add faceting variables to the set of grouping variables. + allow_empty : bool + If True, yield an empty dataframe when no observations exist for + combinations of grouping variables. Yields ------ @@ -971,10 +989,11 @@ def iter_data( grouping_vars = list(grouping_vars) # Always insert faceting variables - facet_vars = {"col", "row"} - grouping_vars.extend( - facet_vars & set(self.variables) - set(grouping_vars) - ) + if by_facet: + facet_vars = {"col", "row"} + grouping_vars.extend( + facet_vars & set(self.variables) - set(grouping_vars) + ) # Reduce to the semantics used in this plot grouping_vars = [ @@ -986,6 +1005,26 @@ def iter_data( else: data = self.plot_data + levels = self.var_levels.copy() + if from_comp_data: + for axis in {"x", "y"} & set(grouping_vars): + if self.var_types[axis] == "categorical": + if self._var_ordered[axis]: + # If the axis is ordered, then the axes in a possible + # facet grid are by definition "shared", or there is a + # single axis with a unique cat -> idx mapping. + # So we can just take the first converter object. + converter = self.converters[axis].iloc[0] + levels[axis] = converter.convert_units(levels[axis]) + else: + # Otherwise, the mappings may not be unique, but we can + # use the unique set of index values in comp_data. 
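+                        # (e.g. with unshared facet axes the same category label can
+                        # map to a different integer position in each facet)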
+ levels[axis] = np.sort(data[axis].unique()) + elif self.var_types[axis] == "datetime": + levels[axis] = mpl.dates.date2num(levels[axis]) + elif self.var_types[axis] == "numeric" and self._log_scaled(axis): + levels[axis] = np.log10(levels[axis]) + if grouping_vars: grouped_data = data.groupby( @@ -994,7 +1033,7 @@ def iter_data( grouping_keys = [] for var in grouping_vars: - grouping_keys.append(self.var_levels.get(var, [])) + grouping_keys.append(levels.get(var, [])) iter_keys = itertools.product(*grouping_keys) if reverse: @@ -1008,7 +1047,14 @@ def iter_data( try: data_subset = grouped_data.get_group(pd_key) except KeyError: - continue + if allow_empty: + # XXX we are adding this to allow backwards compatability + # with the empty artists that old categorical plots would + # add (before 0.12), which we may decide to break, in which + # case this option could be removed + data_subset = pd.DataFrame(columns=data.columns) + else: + continue sub_vars = dict(zip(grouping_vars, key)) @@ -1039,27 +1085,16 @@ def comp_data(self): if var not in self.variables: continue - # Get a corresponding axis object so that we can convert the units - # to matplotlib's numeric representation, which we can compute on - # This is messy and it would probably be better for VectorPlotter - # to manage its own converters (using the matplotlib tools). - # XXX Currently does not support unshared categorical axes! - # (But see comment in _attach about how those don't exist) - if self.ax is None: - ax = self.facets.axes.flat[0] - else: - ax = self.ax - axis = getattr(ax, f"{var}axis") - - # Use the converter assigned to the axis to get a float representation - # of the data, passing np.nan or pd.NA through (pd.NA becomes np.nan) - with pd.option_context('mode.use_inf_as_null', True): - orig = self.plot_data[var].dropna() - comp_col = pd.Series(index=orig.index, dtype=float, name=var) - comp_col.loc[orig.index] = pd.to_numeric(axis.convert_units(orig)) - - if axis.get_scale() == "log": - comp_col = np.log10(comp_col) + comp_col = pd.Series(index=self.plot_data.index, dtype=float, name=var) + grouped = self.plot_data[var].groupby(self.converters[var], sort=False) + for converter, orig in grouped: + with pd.option_context('mode.use_inf_as_null', True): + orig = orig.dropna() + comp = pd.to_numeric(converter.convert_units(orig)) + if converter.get_scale() == "log": + comp = np.log10(comp) + comp_col.loc[orig.index] = comp + comp_data.insert(0, var, comp_col) self._comp_data = comp_data @@ -1081,7 +1116,12 @@ def _get_axes(self, sub_vars): else: return self.ax - def _attach(self, obj, allowed_types=None, log_scale=None): + def _attach( + self, + obj, + allowed_types=None, + log_scale=None, + ): """Associate the plotter with an Axes manager and initialize its units. Parameters @@ -1111,13 +1151,21 @@ def _attach(self, obj, allowed_types=None, log_scale=None): self.facets = None ax_list = [obj] + # Identify which "axis" variables we have defined + axis_variables = set("xy").intersection(self.variables) + + # -- Verify the types of our x and y variables here. + # This doesn't really make complete sense being here here, but it's a fine + # place for it, given the current sytstem. + # (Note that for some plots, there might be more complicated restrictions) + # e.g. the categorical plots have their own check that as specific to the + # non-categorical axis. 
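+        # `allowed_types` may be None (accept every type), a single type name,
+        # or a list of names; normalize it to a list before the check below.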
if allowed_types is None: allowed_types = ["numeric", "datetime", "categorical"] elif isinstance(allowed_types, str): allowed_types = [allowed_types] - for var in set("xy").intersection(self.variables): - # Check types of x/y variables + for var in axis_variables: var_type = self.var_types[var] if var_type not in allowed_types: err = ( @@ -1126,30 +1174,63 @@ def _attach(self, obj, allowed_types=None, log_scale=None): ) raise TypeError(err) - # Register with the matplotlib unit conversion machinery - # Perhaps cleaner to manage our own transform objects? - # XXX Currently this does not allow "unshared" categorical axes - # We could add metadata to a FacetGrid and set units based on that. - # See also comment in comp_data, which only uses a single axes to do - # its mapping, meaning that it won't handle unshared axes well either. - for ax in ax_list: - axis = getattr(ax, f"{var}axis") - seed_data = self.plot_data[var] - if var_type == "categorical": - seed_data = categorical_order(seed_data) - axis.update_units(seed_data) + # -- Get axis objects for each row in plot_data for type conversions and scaling - # For categorical y, we want the "first" level to be at the top of the axis - if self.var_types.get("y", None) == "categorical": - for ax in ax_list: - try: - ax.yaxis.set_inverted(True) - except AttributeError: # mpl < 3.1 - if not ax.yaxis_inverted(): - ax.invert_yaxis() + facet_dim = {"x": "col", "y": "row"} + + self.converters = {} + for var in axis_variables: + other_var = {"x": "y", "y": "x"}[var] + + converter = pd.Series(index=self.plot_data.index, name=var, dtype=object) + share_state = getattr(self.facets, f"_share{var}", True) + + # Simplest cases are that we have a single axes, all axes are shared, + # or sharing is only on the orthogonal facet dimension. In these cases, + # all datapoints get converted the same way, so use the first axis + if share_state is True or share_state == facet_dim[other_var]: + converter.loc[:] = getattr(ax_list[0], f"{var}axis") + + else: - # Possibly log-scale one or both axes - if log_scale is not None: + # Next simplest case is when no axes are shared, and we can + # use the axis objects within each facet + if share_state is False: + for axes_vars, axes_data in self.iter_data(): + ax = self._get_axes(axes_vars) + converter.loc[axes_data.index] = getattr(ax, f"{var}axis") + + # In the more complicated case, the axes are shared within each + # "file" of the facetgrid. 
In that case, we need to subset the data + # for that file and assign it the first axis in the slice of the grid + else: + + names = getattr(self.facets, f"{share_state}_names") + for i, level in enumerate(names): + idx = (i, 0) if share_state == "row" else (0, i) + axis = getattr(self.facets.axes[idx], f"{var}axis") + converter.loc[self.plot_data[share_state] == level] = axis + + # Store the converter vector, which we use elsewhere (e.g comp_data) + self.converters[var] = converter + + # Now actually update the matplotlib objects to do the conversion we want + grouped = self.plot_data[var].groupby(self.converters[var], sort=False) + for converter, seed_data in grouped: + if self.var_types[var] == "categorical": + if self._var_ordered[var]: + order = self.var_levels[var] + else: + order = None + seed_data = categorical_order(seed_data, order) + converter.update_units(seed_data) + + # -- Set numerical axis scales + + # First unpack the log_scale argument + if log_scale is None: + scalex = scaley = False + else: # Allow single value or x, y tuple try: scalex, scaley = log_scale @@ -1157,17 +1238,29 @@ def _attach(self, obj, allowed_types=None, log_scale=None): scalex = log_scale if "x" in self.variables else False scaley = log_scale if "y" in self.variables else False - for axis, scale in zip("xy", (scalex, scaley)): - if scale: - for ax in ax_list: - set_scale = getattr(ax, f"set_{axis}scale") - if scale is True: - set_scale("log") + # Now use it + for axis, scale in zip("xy", (scalex, scaley)): + if scale: + for ax in ax_list: + set_scale = getattr(ax, f"set_{axis}scale") + if scale is True: + set_scale("log") + else: + if LooseVersion(mpl.__version__) >= "3.3": + set_scale("log", base=scale) else: - if LooseVersion(mpl.__version__) >= "3.3": - set_scale("log", base=scale) - else: - set_scale("log", **{f"base{axis}": scale}) + set_scale("log", **{f"base{axis}": scale}) + + # For categorical y, we want the "first" level to be at the top of the axis + if self.var_types.get("y", None) == "categorical": + for ax in ax_list: + try: + ax.yaxis.set_inverted(True) + except AttributeError: # mpl < 3.1 + if not ax.yaxis_inverted(): + ax.invert_yaxis() + + # TODO -- Add axes labels def _log_scaled(self, axis): """Return True if specified axis is log scaled on all attached axes.""" @@ -1202,6 +1295,124 @@ def _add_axis_labels(self, ax, default_x="", default_y=""): y_visible = any(t.get_visible() for t in ax.get_yticklabels()) ax.set_ylabel(self.variables.get("y", default_y), visible=y_visible) + # XXX If the scale_* methods are going to modify the plot_data structure, they + # can't be called twice. That means that if they are called twice, they should + # raise. Alternatively, we could store an original version of plot_data and each + # time they are called they operate on the store, not the current state. + + def scale_native(self, axis, *args, **kwargs): + + # Default, defer to matplotlib + + raise NotImplementedError + + def scale_numeric(self, axis, *args, **kwargs): + + # Feels needed to completeness, what should it do? + # Perhaps handle log scaling? Set the ticker/formatter/limits? + + raise NotImplementedError + + def scale_datetime(self, axis, *args, **kwargs): + + # Use pd.to_datetime to convert strings or numbers to datetime objects + # Note, use day-resolution for numeric->datetime to match matplotlib + + raise NotImplementedError + + def scale_categorical(self, axis, order=None, formatter=None): + """ + Enforce categorical (fixed-scale) rules for the data on given axis. 
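+
+        The variable assigned to this axis is converted to strings, its level
+        order is recorded on the plotter, and downstream code will draw each
+        level at a fixed integer position (0, 1, ...). The categorical
+        functions call this when ``fixed_scale=True`` or the axis variable is
+        already categorical, e.g.
+        ``p.scale_categorical(p.cat_axis, order=order, formatter=formatter)``.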
+ + Parameters + ---------- + axis : "x" or "y" + Axis of the plot to operate on. + order : list + Order that unique values should appear in. + formatter : callable + Function mapping values to a string representation. + + Returns + ------- + self + + """ + # This method both modifies the internal representation of the data + # (converting it to string) and sets some attributes on self. It might be + # a good idea to have a separate object attached to self that contains the + # information in those attributes (i.e. whether to enforce variable order + # across facets, the order to use) similar to the SemanticMapping objects + # we have for semantic variables. That object could also hold the converter + # objects that get used, if we can decouple those from an existing axis + # (cf. https://github.com/matplotlib/matplotlib/issues/19229). + # There are some interactions with faceting information that would need + # to be thought through, since the converts to use depend on facets. + # If we go that route, these methods could become "borrowed" methods similar + # to what happens with the alternate semantic mapper constructors, although + # that approach is kind of fussy and confusing. + + # TODO this method could also set the grid state? Since we like to have no + # grid on the categorical axis by default. Again, a case where we'll need to + # store information until we use it, so best to have a way to collect the + # attributes that this method sets. + + # TODO if we are going to set visual properties of the axes with these methods, + # then we could do the steps currently in CategoricalPlotter._adjust_cat_axis + + # TODO another, and distinct idea, is to expose a cut= param here + + _check_argument("axis", ["x", "y"], axis) + + # Categorical plots can be "univariate" in which case they get an anonymous + # category label on the opposite axis. + if axis not in self.variables: + self.variables[axis] = None + self.var_types[axis] = "categorical" + self.plot_data[axis] = "" + + # If the "categorical" variable has a numeric type, sort the rows so that + # the default result from categorical_order has those values sorted after + # they have been coerced to strings. The reason for this is so that later + # we can get facet-wise orders that are correct. + # XXX Should this also sort datetimes? + # It feels more consistent, but technically will be a default change + # If so, should also change categorical_order to behave that way + if self.var_types[axis] == "numeric": + self.plot_data = self.plot_data.sort_values(axis, kind="mergesort") + + # Now get a reference to the categorical data vector + cat_data = self.plot_data[axis] + + # Get the initial categorical order, which we do before string + # conversion to respect the original types of the order list. + # Track whether the order is given explicitly so that we can know + # whether or not to use the order constructed here downstream + self._var_ordered[axis] = order is not None or cat_data.dtype.name == "category" + order = pd.Index(categorical_order(cat_data, order)) + + # Then convert data to strings. This is because in matplotlib, + # "categorical" data really mean "string" data, so doing this artists + # will be drawn on the categorical axis with a fixed scale. + # TODO implement formatter here; check that it returns strings? 
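+        # Apply the formatter (when given) to both the data and the order so
+        # the two stay consistent; otherwise fall back to plain str conversion.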
+ if formatter is not None: + cat_data = cat_data.map(formatter) + order = order.map(formatter) + else: + cat_data = cat_data.astype(str) + order = order.astype(str) + + # Update the levels list with the type-converted order variable + self.var_levels[axis] = order + + # Now ensure that seaborn will use categorical rules internally + self.var_types[axis] = "categorical" + + # Put the string-typed categorical vector back into the plot_data structure + self.plot_data[axis] = cat_data + + return self + class VariableType(UserString): """ @@ -1364,7 +1575,14 @@ def infer_orient(x=None, y=None, orient=None, require_numeric=True): return "h" elif orient is not None: - raise ValueError(f"Value for `orient` not understood: {orient}") + err = ( + "`orient` must start with 'v' or 'h' or be None, " + f"but `{repr(orient)}` was passed." + ) + raise ValueError(err) + + elif x_type != "categorical" and y_type == "categorical": + return "h" elif x_type != "numeric" and y_type == "numeric": return "v" diff --git a/seaborn/_testing.py b/seaborn/_testing.py index abd933a068..138bdd8870 100644 --- a/seaborn/_testing.py +++ b/seaborn/_testing.py @@ -3,28 +3,7 @@ from numpy.testing import assert_array_equal -LINE_PROPS = [ - "alpha", - "color", - "linewidth", - "linestyle", - "xydata", - "zorder", -] - -COLLECTION_PROPS = [ - "alpha", - "edgecolor", - "facecolor", - "fill", - "hatch", - "linestyle", - "linewidth", - "paths", - "zorder", -] - -BAR_PROPS = [ +USE_PROPS = [ "alpha", "edgecolor", "facecolor", @@ -33,30 +12,36 @@ "height", "linestyle", "linewidth", + "paths", "xy", + "xydata", + "sizes", "zorder", ] -def assert_artists_equal(list1, list2, properties): +def assert_artists_equal(list1, list2): assert len(list1) == len(list2) for a1, a2 in zip(list1, list2): + assert a1.__class__ == a2.__class__ prop1 = a1.properties() prop2 = a2.properties() - for key in properties: + for key in USE_PROPS: + if key not in prop1: + continue v1 = prop1[key] v2 = prop2[key] if key == "paths": for p1, p2 in zip(v1, v2): assert_array_equal(p1.vertices, p2.vertices) assert_array_equal(p1.codes, p2.codes) - elif isinstance(v1, np.ndarray): - assert_array_equal(v1, v2) elif key == "color": v1 = mpl.colors.to_rgba(v1) v2 = mpl.colors.to_rgba(v2) assert v1 == v2 + elif isinstance(v1, np.ndarray): + assert_array_equal(v1, v2) else: assert v1 == v2 @@ -68,21 +53,18 @@ def assert_legends_equal(leg1, leg2): assert t1.get_text() == t2.get_text() assert_artists_equal( - leg1.get_patches(), leg2.get_patches(), BAR_PROPS, + leg1.get_patches(), leg2.get_patches(), ) assert_artists_equal( - leg1.get_lines(), leg2.get_lines(), LINE_PROPS, + leg1.get_lines(), leg2.get_lines(), ) def assert_plots_equal(ax1, ax2, labels=True): - assert_artists_equal(ax1.patches, ax2.patches, BAR_PROPS) - assert_artists_equal(ax1.lines, ax2.lines, LINE_PROPS) - - poly1 = ax1.findobj(mpl.collections.PolyCollection) - poly2 = ax2.findobj(mpl.collections.PolyCollection) - assert_artists_equal(poly1, poly2, COLLECTION_PROPS) + assert_artists_equal(ax1.patches, ax2.patches) + assert_artists_equal(ax1.lines, ax2.lines) + assert_artists_equal(ax1.collections, ax2.collections) if labels: assert ax1.get_xlabel() == ax2.get_xlabel() diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 701cb90254..dd6a327a28 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -443,6 +443,8 @@ def __init__( self._legend_data = {} self._x_var = None self._y_var = None + self._sharex = sharex + self._sharey = sharey self._dropna = dropna self._not_na = not_na diff --git 
a/seaborn/categorical.py b/seaborn/categorical.py index 11fc735d70..57066c7285 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -18,7 +18,12 @@ import matplotlib.patches as Patches import matplotlib.pyplot as plt -from ._core import variable_type, infer_orient, categorical_order +from ._core import ( + VectorPlotter, + variable_type, + infer_orient, + categorical_order, +) from . import utils from .utils import remove_na, _normal_quantile_func from .algorithms import bootstrap @@ -35,6 +40,268 @@ ] +class _CategoricalPlotterNew(VectorPlotter): + + semantics = "x", "y", "hue", "units" + + wide_structure = {"x": "@columns", "y": "@values", "hue": "@columns"} + flat_structure = {"x": "@index", "y": "@values"} + + def __init__( + self, + data=None, + variables={}, + order=None, + orient=None, + require_numeric=False, + fixed_scale=True, + ): + + super().__init__(data=data, variables=variables) + + # This method takes care of some bookkeeping that is necessary because the + # original categorical plots (prior to the 2021 refactor) had some rules that + # don't fit exactly into the logic of _core. It may be wise to have a second + # round of refactoring that moves the logic deeper, but this will keep things + # relatively sensible for now. + + # The concept of an "orientation" is important to the original categorical + # plots, but there's no provision for it in _core, so we need to do it here. + # Note that it could be useful for the other functions in at least two ways + # (orienting a univariate distribution plot from long-form data and selecting + # the aggregation axis in lineplot), so we may want to eventually refactor it. + self.orient = infer_orient( + x=self.plot_data.get("x", None), + y=self.plot_data.get("y", None), + orient=orient, + require_numeric=require_numeric, + ) + + # Short-circuit in the case of an empty plot + if not self.has_xy_data: + return + + # For wide data, orient determines assignment to x/y differently from the + # wide_structure rules in _core. If we do decide to make orient part of the + # _core variable assignment, we'll want to figure out how to express that. + if self.input_format == "wide" and self.orient == "h": + self.plot_data = self.plot_data.rename(columns={"x": "y", "y": "x"}) + orig_x, orig_x_type = self.variables["x"], self.var_types["x"] + orig_y, orig_y_type = self.variables["y"], self.var_types["y"] + self.variables.update({"x": orig_y, "y": orig_x}) + self.var_types.update({"x": orig_y_type, "y": orig_x_type}) + + def _hue_backcompat(self, color, palette, hue_order, force_hue=False): + """Implement backwards compatability for hue parametrization. + + Note: the force_hue parameter is used so that functions can be shown to + pass existing tests during refactoring and then tested for new behavior. + It can be removed after completion of the work. + + """ + # The original categorical functions applied a palette to the categorical axis + # by default. We want to require an explicit hue mapping, to be more consistent + # with how things work elsewhere now. I don't think there's any good way to + # do this gently -- because it's triggered by the default value of hue=None, + # users would always get a warning, unless we introduce some sentinel "default" + # argument for this change. That's possible, but asking users to set `hue=None` + # on every call is annoying. 
+ # We are keeping the logic for implementing the old behavior in with the current + # system so that (a) we can punt on that decision and (b) we can ensure that + # refactored code passes old tests. + default_behavior = color is None or palette is not None + if force_hue and "hue" not in self.variables and default_behavior: + self._redundant_hue = True + self.plot_data["hue"] = self.plot_data[self.cat_axis] + self.variables["hue"] = self.variables[self.cat_axis] + self.var_types["hue"] = "categorical" + hue_order = self.var_levels[self.cat_axis] + + # Because we convert the categorical axis variable to string, + # we need to update a dictionary palette too + if isinstance(palette, dict): + palette = {str(k): v for k, v in palette.items()} + + else: + self._redundant_hue = False + + # Previously, categorical plots had a trick where color= could seed the palette. + # Because that's an explicit parameterization, we are going to give it one + # release cycle with a warning before removing. + if "hue" in self.variables and palette is None and color is not None: + if not isinstance(color, str): + color = mpl.colors.to_hex(color) + palette = f"dark:{color}" + msg = ( + "Setting a gradient palette using color= is deprecated and will be " + f"removed in version 0.13. Set `palette='{palette}'` for same effect." + ) + warnings.warn(msg, FutureWarning) + + return palette, hue_order + + @property + def cat_axis(self): + return {"v": "x", "h": "y"}[self.orient] + + def _get_gray(self, color="C0"): + """Get a grayscale value that looks good with color.""" + if "hue" in self.variables: + rgb_colors = list(self._hue_map.lookup_table.values()) + else: + rgb_colors = [mpl.colors.to_rgb(color)] + + light_vals = [colorsys.rgb_to_hls(*mpl.colors.to_rgb(c))[1] for c in rgb_colors] + lum = min(light_vals) * .6 + gray = mpl.colors.rgb2hex((lum, lum, lum)) + return gray + + def _adjust_cat_axis(self, ax, axis): + """Set ticks and limits for a categorical variable.""" + # Note: in theory, this could happen in _attach for all categorical axes + # But two reasons not to do that: + # - If it happens before plotting, autoscaling messes up the plot limits + # - It would change existing plots from other seaborn functions + if self.var_types[axis] != "categorical": + return + + data = self.plot_data[axis] + if self.facets is not None: + share_group = getattr(ax, f"get_shared_{axis}_axes")() + shared_axes = [getattr(ax, f"{axis}axis")] + [ + getattr(other_ax, f"{axis}axis") + for other_ax in self.facets.axes.flat + if share_group.joined(ax, other_ax) + ] + data = data[self.converters[axis].isin(shared_axes)] + + if self._var_ordered[axis]: + order = categorical_order(data, self.var_levels[axis]) + else: + order = categorical_order(data) + + if axis == "x": + ax.xaxis.grid(False) + ax.set_xlim(-.5, len(order) - .5, auto=None) + else: + ax.yaxis.grid(False) + # Note limits that correspond to previously-inverted y axis + ax.set_ylim(len(order) - .5, -.5, auto=None) + + def plot_strips( + self, + jitter, + dodge, + color, + plot_kws, + ): + + # XXX 2021 refactor notes + # note, original categorical plots do not follow the cycle! + # They probably should ... 
but no changes in this first round of refactoring + # if self.ax is None: + # default_color = "C0" + # else: + # scout = self.ax.scatter([], [], color=color, **plot_kws) + # default_color = scout.get_facecolors() + # scout.remove() + default_color = "C0" if color is None else color + + # TODO this should be centralized + unique_values = np.unique(self.comp_data[self.cat_axis]) + if len(unique_values) > 1: + native_width = np.nanmin(np.diff(unique_values)) + else: + native_width = 1 + width = .8 * native_width + + if jitter is True: + jlim = 0.1 + else: + jlim = float(jitter) + if "hue" in self.variables and dodge: + jlim /= len(self._hue_map.levels) + jlim *= native_width + jitterer = partial(np.random.uniform, low=-jlim, high=+jlim) + + # XXX this is a property on the original class and probably broadly useful + if "hue" in self.variables: + n_levels = len(self._hue_map.levels) + if dodge: + each_width = width / n_levels + offsets = np.linspace(0, width - each_width, n_levels) + offsets -= offsets.mean() + else: + offsets = np.zeros(n_levels) + else: + dodge = False + + # Note that stripplot iterates over categorical positions (and hue levels only + # in the case of dodged strips) to match the original way artists were added. + iter_vars = [self.cat_axis] + if dodge: + iter_vars.append("hue") + + # Note further that, unlike most modern functions, stripplot adds empty + # artists for combinations of variables that have no observations, hence the + # addition/use of allow_empty in iter_data during the 2021 refactor. + + # Initialize ax as otherwise we won't get it when not looping over hue. + # If we are in a faceted context, this will be None, but _get_axes will + # return an Axes later. Perhaps _get_axes should have some awareness of + # cases when x/y are part of the iter_data grouper? + ax = self.ax + + for sub_vars, sub_data in self.iter_data(iter_vars, + from_comp_data=True, + allow_empty=True): + + sub_data = sub_data.dropna() + + if dodge: + dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)] + else: + dodge_move = 0 + + if jitter and len(sub_data) > 1: + jitter_move = jitterer(size=len(sub_data)) + else: + jitter_move = 0 + + sub_data = sub_data.assign(**{ + self.cat_axis: sub_data[self.cat_axis] + dodge_move + jitter_move + }) + + if "hue" in self.variables: + c = self._hue_map(sub_data["hue"]) + else: + c = mpl.colors.to_hex(default_color) + + for var in "xy": + if self._log_scaled(var): + sub_data[var] = np.power(10, sub_data[var]) + + ax = self._get_axes(sub_vars) + ax.scatter(sub_data["x"], sub_data["y"], c=c, **plot_kws) + + # TODO XXX remove redundant hue or always define and use when legend is "auto" + show_legend = not self._redundant_hue and self.input_format != "wide" + if "hue" in self.variables and show_legend: # TODO and legend: + # XXX 2021 refactor notes + # As we know, legends are an ongoing challenge. 
+ # I'm duplicating the old approach here, but I don't love it, + # and it doesn't handle numeric hue mapping properly + for level in self._hue_map.levels: + color = self._hue_map(level) + ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level) + ax.legend(loc="best", title=self.variables["hue"]) + + +class _CategoricalFacetPlotter(_CategoricalPlotterNew): + + semantics = _CategoricalPlotterNew.semantics + ("col", "row") + + class _CategoricalPlotter(object): width = .8 @@ -1089,79 +1356,6 @@ def add_legend_data(self, ax): s=60) -class _StripPlotter(_CategoricalScatterPlotter): - """1-d scatterplot with categorical organization.""" - def __init__(self, x, y, hue, data, order, hue_order, - jitter, dodge, orient, color, palette): - """Initialize the plotter.""" - self.establish_variables(x, y, hue, data, orient, order, hue_order) - self.establish_colors(color, palette, 1) - - # Set object attributes - self.dodge = dodge - self.width = .8 - - if jitter == 1: # Use a good default for `jitter = True` - jlim = 0.1 - else: - jlim = float(jitter) - if self.hue_names is not None and dodge: - jlim /= len(self.hue_names) - self.jitterer = partial(np.random.uniform, low=-jlim, high=+jlim) - - def draw_stripplot(self, ax, kws): - """Draw the points onto `ax`.""" - palette = np.asarray(self.colors) - for i, group_data in enumerate(self.plot_data): - if self.plot_hues is None or not self.dodge: - - if self.hue_names is None: - hue_mask = np.ones(group_data.size, bool) - else: - hue_mask = np.array([h in self.hue_names - for h in self.plot_hues[i]], bool) - # Broken on older numpys - # hue_mask = np.in1d(self.plot_hues[i], self.hue_names) - - strip_data = group_data[hue_mask] - point_colors = np.asarray(self.point_colors[i][hue_mask]) - - # Plot the points in centered positions - cat_pos = np.ones(strip_data.size) * i - cat_pos += self.jitterer(size=len(strip_data)) - kws.update(c=palette[point_colors]) - if self.orient == "v": - ax.scatter(cat_pos, strip_data, **kws) - else: - ax.scatter(strip_data, cat_pos, **kws) - - else: - offsets = self.hue_offsets - for j, hue_level in enumerate(self.hue_names): - hue_mask = self.plot_hues[i] == hue_level - strip_data = group_data[hue_mask] - - point_colors = np.asarray(self.point_colors[i][hue_mask]) - - # Plot the points in centered positions - center = i + offsets[j] - cat_pos = np.ones(strip_data.size) * center - cat_pos += self.jitterer(size=len(strip_data)) - kws.update(c=palette[point_colors]) - if self.orient == "v": - ax.scatter(cat_pos, strip_data, **kws) - else: - ax.scatter(strip_data, cat_pos, **kws) - - def plot(self, ax, kws): - """Make the plot.""" - self.draw_stripplot(ax, kws) - self.add_legend_data(ax) - self.annotate_axes(ax) - if self.orient == "h": - ax.invert_yaxis() - - class _SwarmPlotter(_CategoricalScatterPlotter): def __init__(self, x, y, hue, data, order, hue_order, @@ -2794,30 +2988,67 @@ def stripplot( order=None, hue_order=None, jitter=True, dodge=False, orient=None, color=None, palette=None, size=5, edgecolor="gray", linewidth=0, ax=None, + hue_norm=None, fixed_scale=True, formatter=None, **kwargs ): - if "split" in kwargs: - dodge = kwargs.pop("split") - msg = "The `split` parameter has been renamed to `dodge`." - warnings.warn(msg, UserWarning) + # XXX we need to add a legend= param!!! 
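+    # Overall flow of the refactored function: build the plotter, scale the
+    # categorical axis, attach to the target Axes, resolve the hue mapping
+    # (with backwards-compatible defaults), then draw with plot_strips().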
+ + p = _CategoricalPlotterNew( # TODO update name on switchover + data=data, + variables=_CategoricalPlotterNew.get_semantics(locals()), + order=order, + orient=orient, + require_numeric=False, + fixed_scale=fixed_scale, + ) - plotter = _StripPlotter(x, y, hue, data, order, hue_order, - jitter, dodge, orient, color, palette) if ax is None: ax = plt.gca() + if fixed_scale or p.var_types[p.cat_axis] == "categorical": + p.scale_categorical(p.cat_axis, order=order, formatter=formatter) + + p._attach(ax) + + if not p.has_xy_data: + return ax + + palette, hue_order = p._hue_backcompat(color, palette, hue_order) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + + # XXX Copying possibly bad default decisions from original code for now kwargs.setdefault("zorder", 3) size = kwargs.get("s", size) - if linewidth is None: - linewidth = size / 10 + + # XXX Here especially is tricky. Old code didn't follow the color cycle. + # If new code does, then we won't know the default non-mapped color out here. + # But also I think in general that logic should move to the outer functions. + # XXX Wait how does this work with a custom palette? + # XXX Regardless of implementation, I think we should change this default + # name to "auto" or something similar that doesn't overlap with a real color name if edgecolor == "gray": - edgecolor = plotter.gray - kwargs.update(dict(s=size ** 2, - edgecolor=edgecolor, - linewidth=linewidth)) + edgecolor = p._get_gray("C0" if color is None else color) + + kwargs.update(dict( + s=size ** 2, + edgecolor=edgecolor, + linewidth=linewidth) + ) + + p.plot_strips( + jitter=jitter, + dodge=dodge, + color=color, + plot_kws=kwargs, + ) + + # XXX this happens inside a plotting method in the distribution plots + # but maybe it's better out here? Alternatively, we have an open issue + # suggesting that _attach could add default axes labels, which seems smart. + p._add_axis_labels(ax) + p._adjust_cat_axis(ax, axis=p.cat_axis) - plotter.plot(ax, kwargs) return ax @@ -2877,108 +3108,7 @@ def stripplot( Examples -------- - Draw a single horizontal strip plot: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns - >>> sns.set_theme(style="whitegrid") - >>> tips = sns.load_dataset("tips") - >>> ax = sns.stripplot(x=tips["total_bill"]) - - Group the strips by a categorical variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="day", y="total_bill", data=tips) - - Use a smaller amount of jitter: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="day", y="total_bill", data=tips, jitter=0.05) - - Draw horizontal strips: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="total_bill", y="day", data=tips) - - Draw outlines around the points: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="total_bill", y="day", data=tips, - ... linewidth=1) - - Nest the strips within a second categorical variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="sex", y="total_bill", hue="day", data=tips) - - Draw each level of the ``hue`` variable at different locations on the - major categorical axis: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="day", y="total_bill", hue="smoker", - ... data=tips, palette="Set2", dodge=True) - - Control strip order by passing an explicit order: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="time", y="tip", data=tips, - ... order=["Dinner", "Lunch"]) - - Draw strips with large points and different aesthetics: - - .. 
plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="day", y="total_bill", hue="smoker", - ... data=tips, palette="Set2", size=20, marker="D", - ... edgecolor="gray", alpha=.25) - - Draw strips of observations on top of a box plot: - - .. plot:: - :context: close-figs - - >>> import numpy as np - >>> ax = sns.boxplot(x="tip", y="day", data=tips, whis=np.inf) - >>> ax = sns.stripplot(x="tip", y="day", data=tips, color=".3") - - Draw strips of observations on top of a violin plot: - - .. plot:: - :context: close-figs - - >>> ax = sns.violinplot(x="day", y="total_bill", data=tips, - ... inner=None, color=".8") - >>> ax = sns.stripplot(x="day", y="total_bill", data=tips) - - Use :func:`catplot` to combine a :func:`stripplot` and a - :class:`FacetGrid`. This allows grouping within additional categorical - variables. Using :func:`catplot` is safer than using :class:`FacetGrid` - directly, as it ensures synchronization of variable order across facets: - - .. plot:: - :context: close-figs - - >>> g = sns.catplot(x="sex", y="total_bill", - ... hue="smoker", col="time", - ... data=tips, kind="strip", - ... height=4, aspect=.7); + .. include:: ../docstrings/stripplot.rst """).format(**_categorical_docs) @@ -3737,6 +3867,7 @@ def catplot( orient=None, color=None, palette=None, legend=True, legend_out=True, sharex=True, sharey=True, margin_titles=False, facet_kws=None, + hue_norm=None, fixed_scale=True, formatter=None, **kwargs ): @@ -3754,6 +3885,108 @@ def catplot( err = "Plot kind '{}' is not recognized".format(kind) raise ValueError(err) + # Check for attempt to plot onto specific axes and warn + if "ax" in kwargs: + msg = ("catplot is a figure-level function and does not accept " + f"target axes. You may wish to try {kind}plot") + warnings.warn(msg, UserWarning) + kwargs.pop("ax") + + if kind == "strip": # XXX gradually incorporate the refactored functions + + p = _CategoricalFacetPlotter( + data=data, + variables=_CategoricalFacetPlotter.get_semantics(locals()), + order=order, + orient=orient, + require_numeric=False, + fixed_scale=fixed_scale, + ) + + # XXX Copying a fair amount from displot, which is not ideal + + for var in ["row", "col"]: + # Handle faceting variables that lack name information + if var in p.variables and p.variables[var] is None: + p.variables[var] = f"_{var}_" + + # Adapt the plot_data dataframe for use with FacetGrid + data = p.plot_data.rename(columns=p.variables) + data = data.loc[:, ~data.columns.duplicated()] + + col_name = p.variables.get("col", None) + row_name = p.variables.get("row", None) + + if facet_kws is None: + facet_kws = {} + + g = FacetGrid( + data=data, row=row_name, col=col_name, + col_wrap=col_wrap, row_order=row_order, + col_order=col_order, height=height, + sharex=sharex, sharey=sharey, + aspect=aspect, + **facet_kws, + ) + + if fixed_scale or p.var_types[p.cat_axis] == "categorical": + p.scale_categorical(p.cat_axis, order=order, formatter=formatter) + + p._attach(g) + + if not p.has_xy_data: + return g + + palette, hue_order = p._hue_backcompat(color, palette, hue_order) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + + if kind == "strip": + + # TODO get these defaults programatically? 
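+            # Pull the strip-specific options out of `kwargs` here so that the
+            # remaining keyword arguments can be forwarded to ax.scatter
+            # through plot_strips(plot_kws=...).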
+ jitter = kwargs.pop("jitter", True) + dodge = kwargs.pop("dodge", False) + edgecolor = kwargs.pop("edgecolor", "gray") + + strip_kws = kwargs.copy() + + # XXX Copying possibly bad default decisions from original code for now + strip_kws.setdefault("zorder", 3) + strip_kws.setdefault("s", 25) + + if edgecolor == "gray": + edgecolor = p._get_gray("C0" if color is None else color) + strip_kws["edgecolor"] = edgecolor + + strip_kws.setdefault("linewidth", 0) + + p.plot_strips( + jitter=jitter, + dodge=dodge, + color=color, + plot_kws=strip_kws, + ) + + # XXX best way to do this housekeeping? + for ax in g.axes.flat: + p._adjust_cat_axis(ax, axis=p.cat_axis) + + g.set_axis_labels( + p.variables.get("x", None), + p.variables.get("y", None), + ) + g.set_titles() + g.tight_layout() + + # XXX Hack to get the legend data in the right place + for ax in g.axes.flat: + g._update_legend_data(ax) + ax.legend_ = None + + if legend and (hue is not None) and (hue not in [x, row, col]): + g.add_legend(title=hue, label_order=hue_order) + + return g + # Alias the input variables to determine categorical order and palette # correctly in the case of a count plot if kind == "count": @@ -3766,13 +3999,6 @@ def catplot( else: x_, y_ = x, y - # Check for attempt to plot onto specific axes and warn - if "ax" in kwargs: - msg = ("catplot is a figure-level function and does not accept " - "target axes. You may wish to try {}".format(kind + "plot")) - warnings.warn(msg, UserWarning) - kwargs.pop("ax") - # Determine the order for the whole dataset, which will be used in all # facets to ensure representation of all data in the final plot plotter_class = { @@ -3781,7 +4007,6 @@ def catplot( "boxen": _LVPlotter, "bar": _BarPlotter, "point": _PointPlotter, - "strip": _StripPlotter, "swarm": _SwarmPlotter, "count": _CountPlotter, }[kind] @@ -3812,8 +4037,17 @@ def catplot( # so we need to define ``palette`` to get default behavior for the # categorical functions p.establish_colors(color, palette, 1) - if kind != "point" or hue is not None: - palette = p.colors + if ( + (kind != "point" or hue is not None) + # XXX changing this to temporarily support bad sharex=False behavior where + # cat variables could take different colors, which we already warned + # about "breaking" (aka fixing) in the future + and ((sharex and p.orient == "v") or (sharey and p.orient == "h")) + ): + if p.hue_names is None: + palette = dict(zip(p.group_names, p.colors)) + else: + palette = dict(zip(p.hue_names, p.colors)) # Determine keyword arguments for the facets facet_kws = {} if facet_kws is None else facet_kws diff --git a/seaborn/conftest.py b/seaborn/conftest.py index e14febaeb9..335d673b7d 100644 --- a/seaborn/conftest.py +++ b/seaborn/conftest.py @@ -156,7 +156,8 @@ def long_df(rng): a=rng.choice(list("abc"), n), b=rng.choice(list("mnop"), n), c=rng.choice([0, 1], n, [.3, .7]), - t=rng.choice(np.arange("2004-07-30", "2007-07-30", dtype="datetime64[Y]"), n), + d=rng.choice(np.arange("2004-07-30", "2007-07-30", dtype="datetime64[Y]"), n), + t=rng.choice(np.arange("2004-07-30", "2004-07-31", dtype="datetime64[m]"), n), s=rng.choice([2, 4, 8], n), f=rng.choice([0.2, 0.3], n), )) diff --git a/seaborn/distributions.py b/seaborn/distributions.py index 23772d5168..c23bde651e 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -2159,7 +2159,7 @@ def displot( if "ax" in kwargs: msg = ( "`displot` is a figure-level function and does not accept " - "the ax= paramter. You may wish to try {}plot.".format(kind) + "the ax= parameter. 
You may wish to try {}plot.".format(kind) ) warnings.warn(msg, UserWarning) kwargs.pop("ax") diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index 30b6029868..871d5947fd 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -1,17 +1,88 @@ +import itertools import numpy as np import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt -from matplotlib.colors import rgb2hex +from matplotlib.colors import rgb2hex, to_rgb, to_rgba import pytest from pytest import approx import numpy.testing as npt from distutils.version import LooseVersion +from numpy.testing import ( + assert_array_equal, + assert_array_almost_equal, +) from .. import categorical as cat from .. import palettes -from ..utils import _normal_quantile_func + +from .._core import categorical_order +from ..categorical import ( + _CategoricalPlotterNew, + catplot, + stripplot, +) +from ..palettes import color_palette +from ..utils import _normal_quantile_func, _draw_figure +from .._testing import assert_plots_equal + + +PLOT_FUNCS = [ + catplot, + stripplot, +] + + +class TestCategoricalPlotterNew: + + @pytest.mark.parametrize( + "func,kwargs", + itertools.product( + PLOT_FUNCS, + [ + {"x": "x", "y": "a"}, + {"x": "a", "y": "y"}, + {"x": "y"}, + {"y": "x"}, + ], + ), + ) + def test_axis_labels(self, long_df, func, kwargs): + + func(data=long_df, **kwargs) + + ax = plt.gca() + for axis in "xy": + val = kwargs.get(axis, "") + label_func = getattr(ax, f"get_{axis}label") + assert label_func() == val + + @pytest.mark.parametrize("func", PLOT_FUNCS) + def test_empty(self, func): + + func() + ax = plt.gca() + assert not ax.collections + assert not ax.patches + assert not ax.lines + + def test_redundant_hue_backcompat(self, long_df): + + p = _CategoricalPlotterNew( + data=long_df, + variables={"x": "s", "y": "y"}, + ) + + color = None + palette = dict(zip(long_df["s"].unique(), color_palette())) + hue_order = None + + palette, _ = p._hue_backcompat(color, palette, hue_order, force_hue=True) + + assert p.variables["hue"] == "s" + assert_array_equal(p.plot_data["hue"], p.plot_data["x"]) + assert all(isinstance(k, str) for k in palette) class CategoricalFixture: @@ -1562,150 +1633,464 @@ def test_scatterplot_legend(self): assert tuple(rgb) == tuple(deep_colors[i]) -class TestStripPlotter(CategoricalFixture): +class TestStripPlot: - def test_stripplot_vertical(self): + @pytest.mark.parametrize( + "orient,data_type", + itertools.product(["h", "v"], ["dataframe", "dict"]), + ) + def test_wide(self, wide_df, orient, data_type): - pal = palettes.color_palette() + if data_type == "dict": + wide_df = {k: v.to_numpy() for k, v in wide_df.items()} - ax = cat.stripplot(x="g", y="y", jitter=False, data=self.df) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + ax = stripplot(data=wide_df, orient=orient, jitter=False) + _draw_figure(ax.figure) + palette = color_palette() - x, y = ax.collections[i].get_offsets().T + cat_idx = 0 if orient == "v" else 1 + val_idx = int(not cat_idx) - npt.assert_array_equal(x, np.ones(len(x)) * i) - npt.assert_array_equal(y, vals) + axis_objs = ax.xaxis, ax.yaxis + cat_axis = axis_objs[cat_idx] - npt.assert_equal(ax.collections[i].get_facecolors()[0, :3], pal[i]) + for i, label in enumerate(cat_axis.get_majorticklabels()): + key = label.get_text() + points = ax.collections[i] + point_pos = points.get_offsets().T + val_pos = point_pos[val_idx] + cat_pos = point_pos[cat_idx] - def test_stripplot_horiztonal(self): + assert (cat_pos 
== i).all() + assert_array_equal(val_pos, wide_df[key]) - df = self.df.copy() - df.g = df.g.astype("category") + for point_color in points.get_facecolors(): + assert tuple(point_color) == to_rgba(palette[i]) - ax = cat.stripplot(x="y", y="g", jitter=False, data=df) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + @pytest.mark.parametrize("orient", ["h", "v"]) + def test_flat(self, flat_series, orient): - x, y = ax.collections[i].get_offsets().T + ax = stripplot(data=flat_series, orient=orient, jitter=False) + _draw_figure(ax.figure) - npt.assert_array_equal(x, vals) - npt.assert_array_equal(y, np.ones(len(x)) * i) + cat_idx = 0 if orient == "v" else 1 + val_idx = int(not cat_idx) - def test_stripplot_jitter(self): + axis_objs = ax.xaxis, ax.yaxis + cat_axis = axis_objs[cat_idx] - pal = palettes.color_palette() + for i, label in enumerate(cat_axis.get_majorticklabels()): - ax = cat.stripplot(x="g", y="y", data=self.df, jitter=True) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + points = ax.collections[i] + point_pos = points.get_offsets().T + val_pos = point_pos[val_idx] + cat_pos = point_pos[cat_idx] - x, y = ax.collections[i].get_offsets().T + assert (cat_pos == i).all() - npt.assert_array_less(np.ones(len(x)) * i - .1, x) - npt.assert_array_less(x, np.ones(len(x)) * i + .1) - npt.assert_array_equal(y, vals) + key = int(label.get_text()) # because fixture has integer index + assert_array_equal(val_pos, flat_series[key]) - npt.assert_equal(ax.collections[i].get_facecolors()[0, :3], pal[i]) + @pytest.mark.parametrize( + "variables,orient", + [ + # Order matters for assigning to x/y + ({"cat": "a", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "a", "hue": None}, None), + ({"cat": "a", "val": "y", "hue": "a"}, None), + ({"val": "y", "cat": "a", "hue": "a"}, None), + ({"cat": "a", "val": "y", "hue": "b"}, None), + ({"val": "y", "cat": "a", "hue": "x"}, None), + ({"cat": "s", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "s", "hue": None}, "h"), + ({"cat": "a", "val": "b", "hue": None}, None), + ({"val": "a", "cat": "b", "hue": None}, "h"), + ({"cat": "a", "val": "t", "hue": None}, None), + ({"val": "t", "cat": "a", "hue": None}, None), + ({"cat": "d", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "d", "hue": None}, None), + ({"cat": "a_cat", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "s_cat", "hue": None}, None), + ], + ) + def test_positions(self, long_df, variables, orient): - def test_dodge_nested_stripplot_vertical(self): + cat_var = variables["cat"] + val_var = variables["val"] + hue_var = variables["hue"] + var_names = list(variables.values()) + x_var, y_var, *_ = var_names - pal = palettes.color_palette() + ax = stripplot( + data=long_df, x=x_var, y=y_var, hue=hue_var, + orient=orient, jitter=False, + ) - ax = cat.stripplot(x="g", y="y", hue="h", data=self.df, - jitter=False, dodge=True) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): - for j, (_, vals) in enumerate(group_vals.groupby(self.h)): + _draw_figure(ax.figure) + cat_idx = var_names.index(cat_var) + val_idx = var_names.index(val_var) - x, y = ax.collections[i * 2 + j].get_offsets().T + axis_objs = ax.xaxis, ax.yaxis + cat_axis = axis_objs[cat_idx] + val_axis = axis_objs[val_idx] - npt.assert_array_equal(x, np.ones(len(x)) * i + [-.2, .2][j]) - npt.assert_array_equal(y, vals) + cat_data = long_df[cat_var] + cat_levels = categorical_order(cat_data) - fc = ax.collections[i * 2 + j].get_facecolors()[0, :3] - assert tuple(fc) == pal[j] + for i, label in 
enumerate(cat_levels): - def test_dodge_nested_stripplot_horizontal(self): + vals = long_df.loc[cat_data == label, val_var] - df = self.df.copy() - df.g = df.g.astype("category") + points = ax.collections[i].get_offsets().T + cat_points = points[var_names.index(cat_var)] + val_points = points[var_names.index(val_var)] - ax = cat.stripplot(x="y", y="g", hue="h", data=df, - jitter=False, dodge=True) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): - for j, (_, vals) in enumerate(group_vals.groupby(self.h)): + assert_array_equal(val_points, val_axis.convert_units(vals)) + assert_array_equal(cat_points, np.full(len(cat_points), i)) - x, y = ax.collections[i * 2 + j].get_offsets().T + label = pd.Index([label]).astype(str)[0] + assert cat_axis.get_majorticklabels()[i].get_text() == label - npt.assert_array_equal(x, vals) - npt.assert_array_equal(y, np.ones(len(x)) * i + [-.2, .2][j]) + @pytest.mark.parametrize( + "variables", + [ + # Order matters for assigning to x/y + {"cat": "a", "val": "y", "hue": "b"}, + {"val": "y", "cat": "a", "hue": "c"}, + {"cat": "a", "val": "y", "hue": "f"}, + ], + ) + def test_positions_dodged(self, long_df, variables): - def test_nested_stripplot_vertical(self): + cat_var = variables["cat"] + val_var = variables["val"] + hue_var = variables["hue"] + var_names = list(variables.values()) + x_var, y_var, *_ = var_names - # Test a simple vertical strip plot - ax = cat.stripplot(x="g", y="y", hue="h", data=self.df, - jitter=False, dodge=False) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): + ax = stripplot( + data=long_df, x=x_var, y=y_var, hue=hue_var, dodge=True, jitter=False, + ) - x, y = ax.collections[i].get_offsets().T + cat_vals = categorical_order(long_df[cat_var]) + hue_vals = categorical_order(long_df[hue_var]) - npt.assert_array_equal(x, np.ones(len(x)) * i) - npt.assert_array_equal(y, group_vals) + n_hue = len(hue_vals) + offsets = np.linspace(0, .8, n_hue + 1)[:-1] + offsets -= offsets.mean() - def test_nested_stripplot_horizontal(self): + for i, cat_val in enumerate(cat_vals): + for j, hue_val in enumerate(hue_vals): + rows = (long_df[cat_var] == cat_val) & (long_df[hue_var] == hue_val) + vals = long_df.loc[rows, val_var] - df = self.df.copy() - df.g = df.g.astype("category") + points = ax.collections[n_hue * i + j].get_offsets().T + cat_points = points[var_names.index(cat_var)] + val_points = points[var_names.index(val_var)] - ax = cat.stripplot(x="y", y="g", hue="h", data=df, - jitter=False, dodge=False) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): + if pd.api.types.is_datetime64_any_dtype(vals): + vals = mpl.dates.date2num(vals) - x, y = ax.collections[i].get_offsets().T + assert_array_equal(val_points, vals) + + expected = np.full(len(cat_points), i + offsets[j]) + assert_array_almost_equal(cat_points, expected) + + @pytest.mark.parametrize("cat_var", ["a", "s", "d"]) + def test_positions_unfixed(self, long_df, cat_var): + + long_df = long_df.sort_values(cat_var) + ax = stripplot(data=long_df, x=cat_var, y="y", jitter=False, fixed_scale=False) + + for i, (cat_level, cat_data) in enumerate(long_df.groupby(cat_var)): + + points = ax.collections[i].get_offsets().T + cat_points = points[0] + val_points = points[1] + + comp_level = ax.xaxis.convert_units(cat_level) + + assert_array_equal(cat_points, np.full_like(cat_points, comp_level)) + assert_array_equal(val_points, cat_data["y"]) + + def test_jitter_unfixed(self, long_df): + + ax1, ax2 = plt.figure().subplots(2) + kws = dict(data=long_df, x="y", orient="h", 
fixed_scale=False) + + np.random.seed(0) + stripplot(**kws, y="s", ax=ax1) + + np.random.seed(0) + stripplot(**kws, y=long_df["s"] * 2, ax=ax2) + + p1 = ax1.collections[0].get_offsets()[1] + p2 = ax2.collections[0].get_offsets()[1] + + assert p2.std() > p1.std() + + @pytest.mark.parametrize( + "x_type,order", + [ + (str, None), + (str, ["a", "b", "c"]), + (str, ["c", "a"]), + (str, ["a", "b", "c", "d"]), + (int, None), + (int, [3, 1, 2]), + (int, [3, 1]), + (int, [1, 2, 3, 4]), + (int, ["3", "1", "2"]), + ] + ) + def test_order(self, x_type, order): + + if x_type is str: + x = ["b", "a", "c"] + else: + x = [2, 1, 3] + y = [1, 2, 3] + + ax = stripplot(x=x, y=y, order=order) + _draw_figure(ax.figure) + + if order is None: + order = x + if x_type is int: + order = np.sort(order) + + assert len(ax.collections) == len(order) + tick_labels = ax.xaxis.get_majorticklabels() + + assert ax.get_xlim()[1] == (len(order) - .5) + + for i, points in enumerate(ax.collections): + cat = order[i] + assert tick_labels[i].get_text() == str(cat) + positions = points.get_offsets() + if x_type(cat) in x: + val = y[x.index(x_type(cat))] + assert positions[0, 1] == val + else: + assert not positions.size + + @pytest.mark.parametrize( + "orient,jitter", + itertools.product(["v", "h"], [True, .1]), + ) + def test_jitter(self, long_df, orient, jitter): - npt.assert_array_equal(x, group_vals) - npt.assert_array_equal(y, np.ones(len(x)) * i) + cat_var, val_var = "a", "y" + if orient == "v": + x_var, y_var = cat_var, val_var + cat_idx, val_idx = 0, 1 + else: + x_var, y_var = val_var, cat_var + cat_idx, val_idx = 1, 0 + + cat_vals = categorical_order(long_df[cat_var]) + + ax = stripplot( + data=long_df, x=x_var, y=y_var, jitter=jitter, + ) + + if jitter is True: + jitter_range = .4 + else: + jitter_range = 2 * jitter + + for i, level in enumerate(cat_vals): + + vals = long_df.loc[long_df[cat_var] == level, val_var] + points = ax.collections[i].get_offsets().T + cat_points = points[cat_idx] + val_points = points[val_idx] + + assert_array_equal(val_points, vals) + assert np.std(cat_points) > 0 + assert np.ptp(cat_points) <= jitter_range + + @pytest.mark.parametrize("color", [None, "C1"]) + def test_color(self, long_df, color): + + ax = stripplot(data=long_df, x="a", y="y", color=color) + + expected = to_rgba("C0" if color is None else color) + for points in ax.collections: + for face_color in points.get_facecolors(): + assert tuple(face_color) == expected + + @pytest.mark.parametrize("hue_var", ["a", "b"]) + def test_hue(self, long_df, hue_var): + + cat_var = "b" + + hue_levels = categorical_order(long_df[hue_var]) + cat_levels = categorical_order(long_df[cat_var]) + + pal_name = "muted" + palette = dict(zip(hue_levels, color_palette(pal_name))) + ax = stripplot(data=long_df, x=cat_var, y="y", hue=hue_var, palette=pal_name) + + for i, level in enumerate(cat_levels): + + sub_df = long_df[long_df[cat_var] == level] + point_hues = sub_df[hue_var] + + points = ax.collections[i] + point_colors = points.get_facecolors() + + assert len(point_hues) == len(point_colors) + + for hue, color in zip(point_hues, point_colors): + assert tuple(color) == to_rgba(palette[hue]) + + @pytest.mark.parametrize("hue_var", ["a", "b"]) + def test_hue_dodged(self, long_df, hue_var): + + ax = stripplot(data=long_df, x="y", y="a", hue=hue_var, dodge=True) + colors = color_palette(n_colors=long_df[hue_var].nunique()) + collections = iter(ax.collections) + + # Slightly awkward logic to handle challenges of how the artists work. + # e.g. 
there are empty scatter collections, but the facecolors + # for the empty collections will be the default scatter color + while colors: + points = next(collections) + if points.get_offsets().any(): + face_color = tuple(points.get_facecolors()[0]) + expected_color = to_rgba(colors.pop(0)) + assert face_color == expected_color + + @pytest.mark.parametrize( + "val_var,val_col,hue_col", + itertools.product(["x", "y"], ["b", "y", "t"], [None, "a"]), + ) + def test_single(self, long_df, val_var, val_col, hue_col): + + var_kws = {val_var: val_col, "hue": hue_col} + ax = stripplot(data=long_df, **var_kws, jitter=False) + _draw_figure(ax.figure) + + axis_vars = ["x", "y"] + val_idx = axis_vars.index(val_var) + cat_idx = int(not val_idx) + cat_var = axis_vars[cat_idx] + + cat_axis = getattr(ax, f"{cat_var}axis") + val_axis = getattr(ax, f"{val_var}axis") + + points = ax.collections[0] + point_pos = points.get_offsets().T + cat_pos = point_pos[cat_idx] + val_pos = point_pos[val_idx] + + assert (cat_pos == 0).all() + num_vals = val_axis.convert_units(long_df[val_col]) + assert_array_equal(val_pos, num_vals) + + if hue_col is not None: + palette = dict(zip( + categorical_order(long_df[hue_col]), color_palette() + )) + + facecolors = points.get_facecolors() + for i, color in enumerate(facecolors): + if hue_col is None: + assert tuple(color) == to_rgba("C0") + else: + hue_level = long_df.loc[i, hue_col] + expected_color = palette[hue_level] + assert tuple(color) == to_rgba(expected_color) + + ticklabels = cat_axis.get_majorticklabels() + assert len(ticklabels) == 1 + assert not ticklabels[0].get_text() + + def test_attributes(self, long_df): + + kwargs = dict( + size=2, + linewidth=1, + edgecolor="C2", + ) + + ax = stripplot(x=long_df["y"], **kwargs) + points, = ax.collections + + assert points.get_sizes().item() == kwargs["size"] ** 2 + assert points.get_linewidths().item() == kwargs["linewidth"] + assert tuple(points.get_edgecolors().squeeze()) == to_rgba(kwargs["edgecolor"]) def test_three_strip_points(self): x = np.arange(3) - ax = cat.stripplot(x=x) - facecolors = ax.collections[0].get_facecolor() - assert facecolors.shape == (3, 4) - npt.assert_array_equal(facecolors[0], facecolors[1]) + ax = stripplot(x=x) + for point_color in ax.collections[0].get_facecolor(): + assert tuple(point_color) == to_rgba("C0") - def test_unaligned_index(self): + def test_log_scale(self): - f, (ax1, ax2) = plt.subplots(2) - cat.stripplot(x=self.g, y=self.y, ax=ax1) - cat.stripplot(x=self.g, y=self.y_perm, ax=ax2) - for p1, p2 in zip(ax1.collections, ax2.collections): - y1, y2 = p1.get_offsets()[:, 1], p2.get_offsets()[:, 1] - assert np.array_equal(np.sort(y1), np.sort(y2)) - assert np.array_equal(p1.get_facecolors()[np.argsort(y1)], - p2.get_facecolors()[np.argsort(y2)]) + x = [1, 10, 100, 1000] - f, (ax1, ax2) = plt.subplots(2) - hue_order = self.h.unique() - cat.stripplot(x=self.g, y=self.y, hue=self.h, - hue_order=hue_order, ax=ax1) - cat.stripplot(x=self.g, y=self.y_perm, hue=self.h, - hue_order=hue_order, ax=ax2) - for p1, p2 in zip(ax1.collections, ax2.collections): - y1, y2 = p1.get_offsets()[:, 1], p2.get_offsets()[:, 1] - assert np.array_equal(np.sort(y1), np.sort(y2)) - assert np.array_equal(p1.get_facecolors()[np.argsort(y1)], - p2.get_facecolors()[np.argsort(y2)]) + ax = plt.figure().subplots() + ax.set_xscale("log") + stripplot(x=x) + vals = ax.collections[0].get_offsets()[:, 0] + assert_array_equal(x, vals) - f, (ax1, ax2) = plt.subplots(2) - hue_order = self.h.unique() - cat.stripplot(x=self.g, 
y=self.y, hue=self.h, - dodge=True, hue_order=hue_order, ax=ax1) - cat.stripplot(x=self.g, y=self.y_perm, hue=self.h, - dodge=True, hue_order=hue_order, ax=ax2) - for p1, p2 in zip(ax1.collections, ax2.collections): - y1, y2 = p1.get_offsets()[:, 1], p2.get_offsets()[:, 1] - assert np.array_equal(np.sort(y1), np.sort(y2)) - assert np.array_equal(p1.get_facecolors()[np.argsort(y1)], - p2.get_facecolors()[np.argsort(y2)]) + y = [1, 2, 3, 4] + + ax = plt.figure().subplots() + ax.set_xscale("log") + stripplot(x=x, y=y, fixed_scale=False) + for i, point in enumerate(ax.collections): + val = point.get_offsets()[0, 0] + assert val == x[i] + + def test_palette_from_color_deprecation(self, long_df): + + color = (.9, .4, .5) + hex_color = mpl.colors.to_hex(color) + + hue_var = "a" + n_hue = long_df[hue_var].nunique() + palette = color_palette(f"dark:{hex_color}", n_hue) + + with pytest.warns(FutureWarning, match="Setting a gradient palette"): + ax = stripplot(data=long_df, x="z", hue=hue_var, color=color) + + points = ax.collections[0] + for point_color in points.get_facecolors(): + assert to_rgb(point_color) in palette + + @pytest.mark.parametrize( + "kwargs", + [ + dict(data="wide"), + dict(data="wide", orient="h"), + dict(data="long", x="x", color="C3"), + dict(data="long", y="y", hue="a", jitter=False), + # TODO XXX full numeric hue legend crashes pinned mpl, disabling for now + # dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5), + # dict(data="long", x="a_cat", y="y", hue="z"), + dict(data="long", x="y", y="s", hue="c", orient="h", dodge=True), + dict(data="long", x="s", y="y", hue="c", fixed_scale=False), + ] + ) + def test_vs_catplot(self, long_df, wide_df, kwargs): + + if kwargs["data"] == "long": + kwargs["data"] = long_df + elif kwargs["data"] == "wide": + kwargs["data"] = wide_df + + np.random.seed(0) # for jitter + ax = stripplot(**kwargs) + np.random.seed(0) + g = catplot(**kwargs) + + assert_plots_equal(ax, g.ax) class TestSwarmPlotter(CategoricalFixture): @@ -2580,7 +2965,7 @@ def test_plot_colors(self): def test_ax_kwarg_removal(self): f, ax = plt.subplots() - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match="catplot is a figure-level"): g = cat.catplot(x="g", y="y", data=self.df, ax=ax) assert len(ax.collections) == 0 assert len(g.ax.collections) > 0 @@ -2605,16 +2990,20 @@ def test_share_xy(self): for ax in g.axes.flat: assert len(ax.collections) == len(self.df.g.unique()) - # Test unsharing works with pytest.warns(UserWarning): - g = cat.catplot(x="g", y="y", col="g", data=self.df, sharex=False) + g = cat.catplot( + x="g", y="y", col="g", data=self.df, sharex=False, kind="bar", + ) for ax in g.axes.flat: - assert len(ax.collections) == 1 + assert len(ax.patches) == 1 with pytest.warns(UserWarning): - g = cat.catplot(x="y", y="g", col="g", data=self.df, sharey=False) + g = cat.catplot( + x="y", y="g", col="g", data=self.df, sharey=False, kind="bar", + ) for ax in g.axes.flat: - assert len(ax.collections) == 1 + assert len(ax.patches) == 1 # Make sure no warning is raised if color is provided on unshared plot with pytest.warns(None) as record: @@ -2622,12 +3011,16 @@ def test_share_xy(self): x="g", y="y", col="g", data=self.df, sharex=False, color="b" ) assert not len(record) + for ax in g.axes.flat: + assert ax.get_xlim() == (-.5, .5) with pytest.warns(None) as record: g = cat.catplot( x="y", y="g", col="g", data=self.df, sharey=False, color="r" ) assert not len(record) + for ax in g.axes.flat: + assert 
ax.get_ylim() == (.5, -.5) # Make sure order is used if given, regardless of sharex value order = self.df.g.unique() @@ -2639,6 +3032,15 @@ def test_share_xy(self): for ax in g.axes.flat: assert len(ax.collections) == len(self.df.g.unique()) + @pytest.mark.parametrize("var", ["col", "row"]) + def test_array_faceter(self, long_df, var): + + g1 = catplot(data=long_df, x="y", **{var: "a"}) + g2 = catplot(data=long_df, x="y", **{var: long_df["a"].to_numpy()}) + + for ax1, ax2 in zip(g1.axes.flat, g2.axes.flat): + assert_plots_equal(ax1, ax2) + class TestBoxenPlotter(CategoricalFixture): diff --git a/seaborn/tests/test_core.py b/seaborn/tests/test_core.py index e5c91acde8..8a70ba7de8 100644 --- a/seaborn/tests/test_core.py +++ b/seaborn/tests/test_core.py @@ -1138,6 +1138,67 @@ def test_attach_facets(self, long_df): assert p.ax is None assert p.facets == g + def test_attach_shared_axes(self, long_df): + + g = FacetGrid(long_df) + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"}) + p._attach(g) + assert p.converters["x"].nunique() == 1 + + g = FacetGrid(long_df, col="a") + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"}) + p._attach(g) + assert p.converters["x"].nunique() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", sharex=False) + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"}) + p._attach(g) + assert p.converters["x"].nunique() == p.plot_data["col"].nunique() + assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", sharex=False, col_wrap=2) + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"}) + p._attach(g) + assert p.converters["x"].nunique() == p.plot_data["col"].nunique() + assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b") + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b", sharex=False) + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == len(g.axes.flat) + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b", sharex="col") + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == p.plot_data["col"].nunique() + assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b", sharey="row") + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == 1 + assert p.converters["y"].nunique() == p.plot_data["row"].nunique() + assert p.converters["y"].groupby(p.plot_data["row"]).nunique().max() == 1 + def test_get_axes_single(self, long_df): ax = plt.figure().subplots() @@ -1220,7 +1281,7 @@ def test_comp_data_category_order(self): ) def comp_data_missing_fixture(self, request): - # This fixture holds the logic for parametrizing + # This fixture holds the logic for parameterizing # the following test (test_comp_data_missing) NA, var_type = request.param @@ -1264,6 +1325,61 
@@ def test_var_order(self, long_df): assert p.var_levels[var] == order + def test_scale_native(self, long_df): + + p = VectorPlotter(data=long_df, variables={"x": "x"}) + with pytest.raises(NotImplementedError): + p.scale_native("x") + + def test_scale_numeric(self, long_df): + + p = VectorPlotter(data=long_df, variables={"y": "y"}) + with pytest.raises(NotImplementedError): + p.scale_numeric("y") + + def test_scale_datetime(self, long_df): + + p = VectorPlotter(data=long_df, variables={"x": "t"}) + with pytest.raises(NotImplementedError): + p.scale_datetime("x") + + def test_scale_categorical(self, long_df): + + p = VectorPlotter(data=long_df, variables={"x": "x"}) + p.scale_categorical("y") + assert p.variables["y"] is None + assert p.var_types["y"] == "categorical" + assert (p.plot_data["y"] == "").all() + + p = VectorPlotter(data=long_df, variables={"x": "s"}) + p.scale_categorical("x") + assert p.var_types["x"] == "categorical" + assert hasattr(p.plot_data["x"], "str") + assert not p._var_ordered["x"] + assert p.plot_data["x"].is_monotonic_increasing + assert_array_equal(p.var_levels["x"], p.plot_data["x"].unique()) + + p = VectorPlotter(data=long_df, variables={"x": "a"}) + p.scale_categorical("x") + assert not p._var_ordered["x"] + assert_array_equal(p.var_levels["x"], categorical_order(long_df["a"])) + + p = VectorPlotter(data=long_df, variables={"x": "a_cat"}) + p.scale_categorical("x") + assert p._var_ordered["x"] + assert_array_equal(p.var_levels["x"], categorical_order(long_df["a_cat"])) + + p = VectorPlotter(data=long_df, variables={"x": "a"}) + order = np.roll(long_df["a"].unique(), 1) + p.scale_categorical("x", order=order) + assert p._var_ordered["x"] + assert_array_equal(p.var_levels["x"], order) + + p = VectorPlotter(data=long_df, variables={"x": "s"}) + p.scale_categorical("x", formatter=lambda x: f"{x:%}") + assert p.plot_data["x"].str.endswith("%").all() + assert all(s.endswith("%") for s in p.var_levels["x"]) + class TestCoreFunc: @@ -1327,10 +1443,14 @@ def test_infer_orient(self): nums = pd.Series(np.arange(6)) cats = pd.Series(["a", "b"] * 3) + dates = pd.date_range("1999-09-22", "2006-05-14", 6) assert infer_orient(cats, nums) == "v" assert infer_orient(nums, cats) == "h" + assert infer_orient(cats, dates, require_numeric=False) == "v" + assert infer_orient(dates, cats, require_numeric=False) == "h" + assert infer_orient(nums, None) == "h" with pytest.warns(UserWarning, match="Vertical .+ `x`"): assert infer_orient(nums, None, "v") == "h" @@ -1361,6 +1481,9 @@ def test_infer_orient(self): with pytest.raises(TypeError, match="Neither"): infer_orient(cats, cats) + with pytest.raises(ValueError, match="`orient` must start with"): + infer_orient(cats, nums, orient="bad value") + def test_categorical_order(self): x = ["a", "c", "c", "b", "a", "d"] diff --git a/seaborn/tests/test_distributions.py b/seaborn/tests/test_distributions.py index 493d4ef579..e756d40011 100644 --- a/seaborn/tests/test_distributions.py +++ b/seaborn/tests/test_distributions.py @@ -815,8 +815,8 @@ def test_long_vectors(self, long_df): f, ax2 = plt.subplots() kdeplot(x=x, y=y, ax=ax2) - for c1, c2 in zip(ax1.collections, ax2.collections): - assert_array_equal(c1.get_offsets(), c2.get_offsets()) + for c1, c2 in zip(ax1.collections, ax2.collections): + assert_array_equal(c1.get_offsets(), c2.get_offsets()) def test_singular_data(self): @@ -2040,15 +2040,16 @@ def test_versus_single_ecdfplot(self, long_df, kwargs): ) def test_with_rug(self, long_df, kwargs): - ax = rugplot(data=long_df, **kwargs) + 
ax = plt.figure().subplots() + histplot(data=long_df, **kwargs, ax=ax) + rugplot(data=long_df, **kwargs, ax=ax) + g = displot(long_df, rug=True, **kwargs) - g.ax.patches = [] assert_plots_equal(ax, g.ax, labels=False) long_df["_"] = "_" g2 = displot(long_df, col="_", rug=True, **kwargs) - g2.ax.patches = [] assert_plots_equal(ax, g2.ax, labels=False) diff --git a/seaborn/utils.py b/seaborn/utils.py index 5f2db69109..0d2d466428 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -588,7 +588,7 @@ def _check_argument(param, options, value): """Raise if value for param is not in options.""" if value not in options: raise ValueError( - f"`{param}` must be one of {options}, but {value} was passed.`" + f"`{param}` must be one of {options}, but {repr(value)} was passed." )