Skip to content

Commit

Permalink
Move Variable aggregations to NamedArray (#8304)
Browse files Browse the repository at this point in the history
Co-authored-by: Anderson Banihirwe <[email protected]>
Co-authored-by: Anderson Banihirwe <[email protected]>
  • Loading branch information
3 people authored Oct 17, 2023
1 parent f895dc1 commit 88285f9
Show file tree
Hide file tree
Showing 9 changed files with 1,301 additions and 194 deletions.
30 changes: 30 additions & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,36 @@
IndexVariable.sizes
IndexVariable.values


namedarray.core.NamedArray.all
namedarray.core.NamedArray.any
namedarray.core.NamedArray.attrs
namedarray.core.NamedArray.chunks
namedarray.core.NamedArray.chunksizes
namedarray.core.NamedArray.copy
namedarray.core.NamedArray.count
namedarray.core.NamedArray.cumprod
namedarray.core.NamedArray.cumsum
namedarray.core.NamedArray.data
namedarray.core.NamedArray.dims
namedarray.core.NamedArray.dtype
namedarray.core.NamedArray.get_axis_num
namedarray.core.NamedArray.max
namedarray.core.NamedArray.mean
namedarray.core.NamedArray.median
namedarray.core.NamedArray.min
namedarray.core.NamedArray.nbytes
namedarray.core.NamedArray.ndim
namedarray.core.NamedArray.prod
namedarray.core.NamedArray.reduce
namedarray.core.NamedArray.shape
namedarray.core.NamedArray.size
namedarray.core.NamedArray.sizes
namedarray.core.NamedArray.std
namedarray.core.NamedArray.sum
namedarray.core.NamedArray.var


plot.plot
plot.line
plot.step
Expand Down
3 changes: 0 additions & 3 deletions xarray/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.ops import (
IncludeCumMethods,
IncludeNumpySameMethods,
IncludeReduceMethods,
)
Expand Down Expand Up @@ -99,8 +98,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

class VariableArithmetic(
ImplementsArrayReduce,
IncludeReduceMethods,
IncludeCumMethods,
IncludeNumpySameMethods,
SupportsArithmetic,
VariableOpsMixin,
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6280,7 +6280,7 @@ def dropna(
array = self._variables[k]
if dim in array.dims:
dims = [d for d in array.dims if d != dim]
count += np.asarray(array.count(dims)) # type: ignore[attr-defined]
count += np.asarray(array.count(dims))
size += math.prod([self.dims[d] for d in dims])

if thresh is not None:
Expand Down
32 changes: 16 additions & 16 deletions xarray/core/formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _load_static_files():
]


def short_data_repr_html(array):
def short_data_repr_html(array) -> str:
"""Format "data" for DataArray and Variable."""
internal_data = getattr(array, "variable", array)._data
if hasattr(internal_data, "_repr_html_"):
Expand All @@ -37,7 +37,7 @@ def short_data_repr_html(array):
return f"<pre>{text}</pre>"


def format_dims(dims, dims_with_index):
def format_dims(dims, dims_with_index) -> str:
if not dims:
return ""

Expand All @@ -53,7 +53,7 @@ def format_dims(dims, dims_with_index):
return f"<ul class='xr-dim-list'>{dims_li}</ul>"


def summarize_attrs(attrs):
def summarize_attrs(attrs) -> str:
attrs_dl = "".join(
f"<dt><span>{escape(str(k))} :</span></dt>" f"<dd>{escape(str(v))}</dd>"
for k, v in attrs.items()
Expand All @@ -62,7 +62,7 @@ def summarize_attrs(attrs):
return f"<dl class='xr-attrs'>{attrs_dl}</dl>"


def _icon(icon_name):
def _icon(icon_name) -> str:
# icon_name should be defined in xarray/static/html/icon-svg-inline.html
return (
f"<svg class='icon xr-{icon_name}'>"
Expand All @@ -72,7 +72,7 @@ def _icon(icon_name):
)


def summarize_variable(name, var, is_index=False, dtype=None):
def summarize_variable(name, var, is_index=False, dtype=None) -> str:
variable = var.variable if hasattr(var, "variable") else var

cssclass_idx = " class='xr-has-index'" if is_index else ""
Expand Down Expand Up @@ -109,7 +109,7 @@ def summarize_variable(name, var, is_index=False, dtype=None):
)


def summarize_coords(variables):
def summarize_coords(variables) -> str:
li_items = []
for k, v in variables.items():
li_content = summarize_variable(k, v, is_index=k in variables.xindexes)
Expand All @@ -120,7 +120,7 @@ def summarize_coords(variables):
return f"<ul class='xr-var-list'>{vars_li}</ul>"


def summarize_vars(variables):
def summarize_vars(variables) -> str:
vars_li = "".join(
f"<li class='xr-var-item'>{summarize_variable(k, v)}</li>"
for k, v in variables.items()
Expand All @@ -129,14 +129,14 @@ def summarize_vars(variables):
return f"<ul class='xr-var-list'>{vars_li}</ul>"


def short_index_repr_html(index):
def short_index_repr_html(index) -> str:
if hasattr(index, "_repr_html_"):
return index._repr_html_()

return f"<pre>{escape(repr(index))}</pre>"


def summarize_index(coord_names, index):
def summarize_index(coord_names, index) -> str:
name = "<br>".join([escape(str(n)) for n in coord_names])

index_id = f"index-{uuid.uuid4()}"
Expand All @@ -155,7 +155,7 @@ def summarize_index(coord_names, index):
)


def summarize_indexes(indexes):
def summarize_indexes(indexes) -> str:
indexes_li = "".join(
f"<li class='xr-var-item'>{summarize_index(v, i)}</li>"
for v, i in indexes.items()
Expand All @@ -165,7 +165,7 @@ def summarize_indexes(indexes):

def collapsible_section(
name, inline_details="", details="", n_items=None, enabled=True, collapsed=False
):
) -> str:
# "unique" id to expand/collapse the section
data_id = "section-" + str(uuid.uuid4())

Expand All @@ -187,7 +187,7 @@ def collapsible_section(

def _mapping_section(
mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True
):
) -> str:
n_items = len(mapping)
expanded = _get_boolean_with_default(
expand_option_name, n_items < max_items_collapse
Expand All @@ -203,15 +203,15 @@ def _mapping_section(
)


def dim_section(obj):
def dim_section(obj) -> str:
dim_list = format_dims(obj.dims, obj.xindexes.dims)

return collapsible_section(
"Dimensions", inline_details=dim_list, enabled=False, collapsed=True
)


def array_section(obj):
def array_section(obj) -> str:
# "unique" id to expand/collapse the section
data_id = "section-" + str(uuid.uuid4())
collapsed = (
Expand Down Expand Up @@ -296,7 +296,7 @@ def _obj_repr(obj, header_components, sections):
)


def array_repr(arr):
def array_repr(arr) -> str:
dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape))
if hasattr(arr, "xindexes"):
indexed_dims = arr.xindexes.dims
Expand Down Expand Up @@ -326,7 +326,7 @@ def array_repr(arr):
return _obj_repr(arr, header_components, sections)


def dataset_repr(ds):
def dataset_repr(ds) -> str:
obj_type = f"xarray.{type(ds).__name__}"

header_components = [f"<div class='xr-obj-type'>{escape(obj_type)}</div>"]
Expand Down
25 changes: 0 additions & 25 deletions xarray/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
"var",
"median",
]
NAN_CUM_METHODS = ["cumsum", "cumprod"]
# TODO: wrap take, dot, sort


Expand Down Expand Up @@ -263,20 +262,6 @@ def inject_reduce_methods(cls):
setattr(cls, name, func)


def inject_cum_methods(cls):
methods = [(name, getattr(duck_array_ops, name), True) for name in NAN_CUM_METHODS]
for name, f, include_skipna in methods:
numeric_only = getattr(f, "numeric_only", False)
func = cls._reduce_method(f, include_skipna, numeric_only)
func.__name__ = name
func.__doc__ = _CUM_DOCSTRING_TEMPLATE.format(
name=name,
cls=cls.__name__,
extra_args=cls._cum_extra_args_docstring.format(name=name),
)
setattr(cls, name, func)


def op_str(name):
return f"__{name}__"

Expand Down Expand Up @@ -316,16 +301,6 @@ def __init_subclass__(cls, **kwargs):
inject_reduce_methods(cls)


class IncludeCumMethods:
__slots__ = ()

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)

if getattr(cls, "_reduce_method", None):
inject_cum_methods(cls)


class IncludeNumpySameMethods:
__slots__ = ()

Expand Down
68 changes: 15 additions & 53 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import numbers
import warnings
from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast
Expand Down Expand Up @@ -1704,7 +1704,7 @@ def clip(self, min=None, max=None):

return apply_ufunc(np.clip, self, min, max, dask="allowed")

def reduce(
def reduce( # type: ignore[override]
self,
func: Callable[..., Any],
dim: Dims = None,
Expand Down Expand Up @@ -1745,59 +1745,21 @@ def reduce(
Array with summarized data and the indicated dimension(s)
removed.
"""
if dim == ...:
dim = None
if dim is not None and axis is not None:
raise ValueError("cannot supply both 'axis' and 'dim' arguments")

if dim is not None:
axis = self.get_axis_num(dim)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", r"Mean of empty slice", category=RuntimeWarning
)
if axis is not None:
if isinstance(axis, tuple) and len(axis) == 1:
# unpack axis for the benefit of functions
# like np.argmin which can't handle tuple arguments
axis = axis[0]
data = func(self.data, axis=axis, **kwargs)
else:
data = func(self.data, **kwargs)

if getattr(data, "shape", ()) == self.shape:
dims = self.dims
else:
removed_axes: Iterable[int]
if axis is None:
removed_axes = range(self.ndim)
else:
removed_axes = np.atleast_1d(axis) % self.ndim
if keepdims:
# Insert np.newaxis for removed dims
slices = tuple(
np.newaxis if i in removed_axes else slice(None, None)
for i in range(self.ndim)
)
if getattr(data, "shape", None) is None:
# Reduce has produced a scalar value, not an array-like
data = np.asanyarray(data)[slices]
else:
data = data[slices]
dims = self.dims
else:
dims = tuple(
adim for n, adim in enumerate(self.dims) if n not in removed_axes
)
keep_attrs_ = (
_get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
)

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
attrs = self._attrs if keep_attrs else None
# Noe that the call order for Variable.mean is
# Variable.mean -> NamedArray.mean -> Variable.reduce
# -> NamedArray.reduce
result = super().reduce(
func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs
)

# We need to return `Variable` rather than the type of `self` at the moment, ref
# #8216
return Variable(dims, data, attrs=attrs)
# return Variable always to support IndexVariable
return Variable(
result.dims, result._data, attrs=result._attrs if keep_attrs_ else None
)

@classmethod
def concat(
Expand Down
Loading

0 comments on commit 88285f9

Please sign in to comment.