Skip to content

Commit

Permalink
Add Convention.select_indexes() and related functions
Browse files Browse the repository at this point in the history
This is a breaking change for plugins as it changes the abstract methods
on the Convention class. `Convention.selector_for_indexes()` is a new
required method, and `Convention.selector_for_index()` now has a default
implementation which calls the new method.
  • Loading branch information
mx-moth committed Jul 10, 2024
1 parent 43fd43d commit 8e16424
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 54 deletions.
152 changes: 140 additions & 12 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from collections.abc import Callable, Hashable, Iterable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

import numpy
import xarray
Expand All @@ -17,7 +17,7 @@
from emsarray import utils
from emsarray.compat.shapely import SpatialIndex
from emsarray.exceptions import NoSuchCoordinateError
from emsarray.operations import depth
from emsarray.operations import depth, point_extraction
from emsarray.plot import (
_requires_plot, animate_on_figure, make_plot_title, plot_on_figure,
polygons_to_collection
Expand Down Expand Up @@ -516,7 +516,7 @@ def get_depth_coordinate_for_data_array(
candidates = [
coordinate
for coordinate in self.depth_coordinates
if coordinate.dims[0] in data_array.dims
if set(coordinate.dims) <= set(data_array.dims)
]
if len(candidates) == 0:
raise NoSuchCoordinateError(f"No depth coordinate found for {name}")
Expand Down Expand Up @@ -1471,8 +1471,7 @@ def get_index_for_point(
polygon=self.polygons[linear_index])
return None

@abc.abstractmethod
def selector_for_index(self, index: Index) -> dict[Hashable, int]:
def selector_for_index(self, index: Index) -> xarray.Dataset:
"""
Convert a convention native index into a selector
that can be passed to :meth:`Dataset.isel <xarray.Dataset.isel>`.
Expand All @@ -1494,11 +1493,24 @@ def selector_for_index(self, index: Index) -> dict[Hashable, int]:
:meth:`.select_point`
:ref:`indexing`
"""
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
dataset = self.selector_for_indexes([index], index_dimension=index_dimension)
dataset = dataset.squeeze(dim=index_dimension, drop=False)
return dataset

@abc.abstractmethod
def selector_for_indexes(
self,
indexes: list[Index],
*,
index_dimension: Hashable | None = None,
) -> xarray.Dataset:
pass

def select_index(
self,
index: Index,
drop_geometry: bool = True,
) -> xarray.Dataset:
"""
Return a new dataset that contains values only from a single index.
Expand All @@ -1516,6 +1528,62 @@ def select_index(
index : :data:`Index`
The index to select.
The index must be for the default grid kind for this dataset.
drop_geometry : bool, default True
Whether to drop geometry variables from the returned point dataset.
If the geometry data is kept,
it will no longer conform to the dataset convention
and may not conform to any sensible convention at all.
The format of the geometry data left after selecting points is convention-dependent.
Returns
-------
:class:`xarray.Dataset`
A new dataset that is subset to the one index.
Notes
-----
The returned dataset will most likely not have sufficient coordinate data
to be used with a particular :class:`Convention` any more.
The ``dataset.ems`` accessor will raise an error if accessed on the new dataset.
"""
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
dataset = self.select_indexes([index], index_dimension=index_dimension, drop_geometry=drop_geometry)
dataset = dataset.squeeze(dim=index_dimension, drop=False)
return dataset

def select_indexes(
self,
indexes: list[Index],
*,
index_dimension: Hashable | None = None,
drop_geometry: bool = True,
) -> xarray.Dataset:
"""
Return a new dataset that contains values only at the selected indexes.
This is much like doing a :func:`xarray.Dataset.isel()` on some indexes,
but works with convention native index types.
An index is associated with a grid kind.
The returned dataset will only contain variables that were defined on this grid,
with the indexed points selected.
For example, if the index of a face is passed in,
the returned dataset will not contain any variables defined on an edge.
Parameters
----------
indexes : list of :data:`Index`
The indexes to select.
The indexes must be for the default grid kind for this dataset.
index_dimension : str, optional
The name of the new dimension added for each index to select.
Defaults to the :func:`first unused dimension <.utils.find_unused_dimension>` with prefix ``index``.
drop_geometry : bool, default True
Whether to drop geometry variables from the returned point dataset.
If the geometry data is kept,
it will no longer conform to the dataset convention
and may not conform to any sensible convention at all.
The format of the geometry data left after selecting points is convention-dependent.
Returns
-------
Expand All @@ -1529,16 +1597,21 @@ def select_index(
to be used with a particular :class:`Convention` any more.
The ``dataset.ems`` accessor will raise an error if accessed on the new dataset.
"""
selector = self.selector_for_index(index)
selector = self.selector_for_indexes(indexes, index_dimension=index_dimension)

# Make a new dataset consisting of only data arrays that use at least
# one of these dimensions.
dims = set(selector.keys())
if drop_geometry:
dataset = self.drop_geometry()
else:
dataset = self.dataset

dims = set(selector.variables.keys())
names = [
name for name, data_array in self.dataset.items()
name for name, data_array in dataset.items()
if dims.intersection(data_array.dims)
]
dataset = utils.extract_vars(self.dataset, names)
dataset = utils.extract_vars(dataset, names)

# Select just these points
return dataset.isel(selector)
Expand All @@ -1565,6 +1638,38 @@ def select_point(self, point: Point) -> xarray.Dataset:
raise ValueError("Point did not intersect dataset")
return self.select_index(index.index)

def select_points(
self,
points: list[Point],
*,
point_dimension: Hashable | None = None,
missing_points: Literal['error', 'drop'] = 'error',
) -> xarray.Dataset:
"""
Extract values from all variables on the default grid at a sequence of points.
Parameters
----------
points : list of shapely.Point
The points to extract
point_dimension : str, optional
The name of the new dimension used to index points.
Defaults to 'point', or 'point_0', 'point_1', etc if those dimensions already exist.
missing_points : {'error', 'drop'}, default 'error'
What to do if a point does not intersect the dataset.
'error' will raise an error, while 'drop' will drop those points.
Returns
-------
xarray.Dataset
A dataset with values extracted from the points.
Variables not defined on the default grid and geometry variables will not be present.
"""
if point_dimension is None:
point_dimension = utils.find_unused_dimension(self.dataset, 'point')
return point_extraction.extract_points(
self.dataset, points, point_dimension=point_dimension, missing_points=missing_points)

@abc.abstractmethod
def get_all_geometry_names(self) -> list[Hashable]:
"""
Expand Down Expand Up @@ -1936,7 +2041,30 @@ def wind(
dimensions=dimensions, sizes=sizes,
linear_dimension=linear_dimension)

def selector_for_index(self, index: Index) -> dict[Hashable, int]:
grid_kind, indexes = self.unpack_index(index)
def selector_for_indexes(
self,
indexes: list[Index],
*,
index_dimension: Hashable | None = None,
) -> xarray.Dataset:
if index_dimension is None:
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
if len(indexes) == 0:
raise ValueError("Need at least one index to select")

grid_kinds, index_tuples = zip(*[self.unpack_index(index) for index in indexes])

unique_grid_kinds = set(grid_kinds)
if len(unique_grid_kinds) > 1:
raise ValueError(
"All indexes must be on the same grid kind, got "
+ ", ".join(map(repr, unique_grid_kinds)))

grid_kind = grid_kinds[0]
dimensions = self.grid_dimensions[grid_kind]
return dict(zip(dimensions, indexes))
# This array will have shape (len(indexes), len(dimensions))
index_array = numpy.array(index_tuples)
return xarray.Dataset({
dimension: (index_dimension, index_array[:, i])
for i, dimension in enumerate(dimensions)
})
27 changes: 13 additions & 14 deletions src/emsarray/operations/point_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import xarray
import xarray.core.dtypes as xrdtypes

import emsarray.conventions
from emsarray import utils
from emsarray.conventions import Convention


@dataclasses.dataclass
Expand Down Expand Up @@ -67,7 +67,7 @@ def extract_points(
dataset: xarray.Dataset,
points: list[shapely.Point],
*,
point_dimension: Hashable = 'point',
point_dimension: Hashable | None = None,
missing_points: Literal['error', 'drop'] = 'error',
) -> xarray.Dataset:
"""
Expand Down Expand Up @@ -106,7 +106,10 @@ def extract_points(
--------
:func:`extract_dataframe`
"""
convention: Convention = dataset.ems
convention: emsarray.conventions.Convention = dataset.ems

if point_dimension is None:
point_dimension = utils.find_unused_dimension(dataset, 'point')

# Find the indexer for each given point
indexes = numpy.array([convention.get_index_for_point(point) for point in points])
Expand All @@ -118,21 +121,17 @@ def extract_points(
indexes=out_of_bounds,
points=[points[i] for i in out_of_bounds])

# Make a DataFrame out of all point indexers
selector = convention.selector_for_indexes([i.index for i in indexes])
selector_df = pandas.DataFrame([
convention.selector_for_index(index.index)
for index in indexes
if index is not None])
point_indexes = [i for i, index in enumerate(indexes) if index is not None]
point_ds = convention.select_indexes(
[index.index for index in indexes if index is not None],
index_dimension=point_dimension,
drop_geometry=True)

# Subset the dataset to the points
point_ds = convention.drop_geometry()
selector_ds = _dataframe_to_dataset(selector_df, dimension_name=point_dimension)
point_ds = point_ds.isel(selector_ds)
# Number the points
point_indexes = [i for i, index in enumerate(indexes) if index is not None]
point_ds = point_ds.assign_coords({
point_dimension: ([point_dimension], point_indexes),
})

return point_ds


Expand Down
16 changes: 14 additions & 2 deletions tests/conventions/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,20 @@ def wind_index(
def ravel_index(self, indexes: SimpleGridIndex) -> int:
return int(numpy.ravel_multi_index((indexes.y, indexes.x), self.shape))

def selector_for_index(self, index: SimpleGridIndex) -> dict[Hashable, int]:
return {'x': index.x, 'y': index.y}
def selector_for_indexes(
self,
indexes: list[SimpleGridIndex],
*,
index_dimension: Hashable | None = None,
) -> xarray.Dataset:
if index_dimension is None:
index_dimension = utils.find_unused_dimension(self.dataset, 'index')
index_array = numpy.array([[index.x, index.y] for index in indexes])

return xarray.Dataset({
'x': (index_dimension, index_array[:, 0]),
'y': (index_dimension, index_array[:, 1]),
})

def ravel(
self,
Expand Down
Loading

0 comments on commit 8e16424

Please sign in to comment.