Miscellaneous lazy preprocessor improvements #2520

Draft: wants to merge 4 commits into base: main
27 changes: 21 additions & 6 deletions esmvalcore/preprocessor/_area.py
@@ -21,6 +21,8 @@
from iris.exceptions import CoordinateNotFoundError

from esmvalcore.preprocessor._shared import (
apply_mask,
get_dims_along_axes,
get_iris_aggregator,
get_normalized_cube,
preserve_float_dtype,
@@ -188,8 +190,8 @@
cube = cube[..., i_slice, j_slice]
selection = selection[i_slice, j_slice]
# Mask remaining coordinates outside region
mask = da.broadcast_to(~selection, cube.shape)
cube.data = da.ma.masked_where(mask, cube.core_data())
horizontal_dims = get_dims_along_axes(cube, ["X", "Y"])
cube.data = apply_mask(~selection, cube.core_data(), horizontal_dims)
return cube


@@ -857,8 +859,8 @@
_cube.add_aux_coord(
AuxCoord(id_, units="no_unit", long_name="shape_id")
)
mask = da.broadcast_to(mask, _cube.shape)
_cube.data = da.ma.masked_where(~mask, _cube.core_data())
horizontal_dims = get_dims_along_axes(cube, axes=["X", "Y"])
_cube.data = apply_mask(~mask, _cube.core_data(), horizontal_dims)
cubelist.append(_cube)
result = fix_coordinate_ordering(cubelist.merge_cube())
if cube.cell_measures():
@@ -869,8 +871,21 @@
# (time, shape_id, depth, lat, lon)
if measure.ndim > 3 and result.ndim > 4:
data = measure.core_data()
data = da.expand_dims(data, axis=(1,))
data = da.broadcast_to(data, result.shape)
dim_map = get_dims_along_axes(result, ["T", "Z", "Y", "X"])
if cube.has_lazy_data():
chunks = cube.lazy_data().chunks
data = da.asarray(
data,
chunks=tuple(chunks[i] for i in dim_map),
)
else:
chunks = None
data = iris.util.broadcast_to_shape(
data,
result.shape,
dim_map=dim_map,
chunks=chunks,
)
measure = iris.coords.CellMeasure(
data,
standard_name=measure.standard_name,
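For reference, a minimal sketch of the new masking pattern used in this file (run against this branch, where apply_mask lives in esmvalcore.preprocessor._shared; shapes and chunk sizes are made up for illustration): a 2-D horizontal mask is mapped onto a lazy (time, lat, lon) array via dim_map, mirroring the extract_region and shapefile-masking calls above.

import dask.array as da
import numpy as np

from esmvalcore.preprocessor._shared import apply_mask

# Lazy (time, lat, lon) data with one chunk per time step.
data = da.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4).rechunk((1, 3, 4))

# 2-D horizontal mask; True means "mask out".
mask = np.zeros((3, 4), dtype=bool)
mask[0, :] = True

# dim_map maps the mask dimensions (lat, lon) onto array dimensions 1 and 2.
masked = apply_mask(mask, data, dim_map=(1, 2))
print(masked.compute())  # the first latitude row is masked at every time step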
2 changes: 2 additions & 0 deletions esmvalcore/preprocessor/_io.py
@@ -22,6 +22,7 @@
from esmvalcore.cmor.check import CheckLevels
from esmvalcore.esgf.facets import FACETS
from esmvalcore.iris_helpers import merge_cube_attributes
from esmvalcore.preprocessor._shared import _rechunk_aux_factory_dependencies

from .._task import write_ncl_settings

@@ -392,6 +393,7 @@ def concatenate(cubes, check_level=CheckLevels.DEFAULT):
cubes = _sort_cubes_by_time(cubes)
_fix_calendars(cubes)
cubes = _check_time_overlaps(cubes)
cubes = [_rechunk_aux_factory_dependencies(cube) for cube in cubes]
result = _concatenate_cubes(cubes, check_level=check_level)

if len(result) == 1:
32 changes: 5 additions & 27 deletions esmvalcore/preprocessor/_mask.py
@@ -9,8 +9,7 @@

import logging
import os
from collections.abc import Iterable
from typing import Literal, Optional
from typing import Literal

import cartopy.io.shapereader as shpreader
import dask.array as da
@@ -22,7 +21,7 @@
from iris.cube import Cube
from iris.util import rolling_window

from esmvalcore.preprocessor._shared import get_array_module
from esmvalcore.preprocessor._shared import apply_mask

from ._supplementary_vars import register_supplementaries

@@ -61,24 +60,6 @@ def _get_fx_mask(
return inmask


def _apply_mask(
mask: np.ndarray | da.Array,
array: np.ndarray | da.Array,
dim_map: Optional[Iterable[int]] = None,
) -> np.ndarray | da.Array:
"""Apply a (broadcasted) mask on an array."""
npx = get_array_module(mask, array)
if dim_map is not None:
if isinstance(array, da.Array):
chunks = array.chunks
else:
chunks = None
mask = iris.util.broadcast_to_shape(
mask, array.shape, dim_map, chunks=chunks
)
return npx.ma.masked_where(mask, array)


@register_supplementaries(
variables=["sftlf", "sftof"],
required="prefer_at_least_one",
@@ -145,7 +126,7 @@ def mask_landsea(cube: Cube, mask_out: Literal["land", "sea"]) -> Cube:
landsea_mask = _get_fx_mask(
ancillary_var.core_data(), mask_out, ancillary_var.var_name
)
cube.data = _apply_mask(
cube.data = apply_mask(
landsea_mask,
cube.core_data(),
cube.ancillary_variable_dims(ancillary_var),
@@ -212,7 +193,7 @@ def mask_landseaice(cube: Cube, mask_out: Literal["landsea", "ice"]) -> Cube:
landseaice_mask = _get_fx_mask(
ancillary_var.core_data(), mask_out, ancillary_var.var_name
)
cube.data = _apply_mask(
cube.data = apply_mask(
landseaice_mask,
cube.core_data(),
cube.ancillary_variable_dims(ancillary_var),
@@ -350,10 +331,7 @@ def _mask_with_shp(cube, shapefilename, region_indices=None):
else:
mask |= shp_vect.contains(region, x_p_180, y_p_90)

if cube.has_lazy_data():
mask = da.array(mask)

cube.data = _apply_mask(
cube.data = apply_mask(
mask,
cube.core_data(),
cube.coord_dims("latitude") + cube.coord_dims("longitude"),
31 changes: 1 addition & 30 deletions esmvalcore/preprocessor/_regrid.py
@@ -32,6 +32,7 @@
from esmvalcore.exceptions import ESMValCoreDeprecationWarning
from esmvalcore.iris_helpers import has_irregular_grid, has_unstructured_grid
from esmvalcore.preprocessor._shared import (
_rechunk_aux_factory_dependencies,
get_array_module,
get_dims_along_axes,
preserve_float_dtype,
@@ -1174,36 +1175,6 @@ def parse_vertical_scheme(scheme):
return scheme, extrap_scheme


def _rechunk_aux_factory_dependencies(
cube: iris.cube.Cube,
coord_name: str,
) -> iris.cube.Cube:
"""Rechunk coordinate aux factory dependencies.

This ensures that the resulting coordinate has reasonably sized
chunks that are aligned with the cube data for optimal computational
performance.
"""
# Workaround for https://github.com/SciTools/iris/issues/5457
try:
factory = cube.aux_factory(coord_name)
except iris.exceptions.CoordinateNotFoundError:
return cube

cube = cube.copy()
cube_chunks = cube.lazy_data().chunks
for coord in factory.dependencies.values():
coord_dims = cube.coord_dims(coord)
if coord_dims:
coord = coord.copy()
chunks = tuple(cube_chunks[i] for i in coord_dims)
coord.points = coord.lazy_points().rechunk(chunks)
if coord.has_bounds():
coord.bounds = coord.lazy_bounds().rechunk(chunks + (None,))
cube.replace_coord(coord)
return cube


@preserve_float_dtype
def extract_levels(
cube: iris.cube.Cube,
82 changes: 82 additions & 0 deletions esmvalcore/preprocessor/_shared.py
@@ -544,3 +544,85 @@ def get_dims_along_coords(
"""Get a tuple with the dimensions along one or more coordinates."""
dims = {d for coord in coords for d in _get_dims_along(cube, coord)}
return tuple(sorted(dims))


def apply_mask(
mask: np.ndarray | da.Array,
array: np.ndarray | da.Array,
dim_map: Iterable[int],
) -> np.ma.MaskedArray | da.Array:
"""Apply a (broadcasted) mask on an array.

Parameters
----------
mask:
The mask to apply to array.
array:
The array to mask out.
dim_map:
A mapping of the dimensions of *mask* to their corresponding
dimensions in *array*. *dim_map* must have the same length as the
number of dimensions in *mask*: each element corresponds to a
dimension of *mask* and gives the index of the matching dimension
in *array*. For example, the first element of *dim_map* is the
index of the *array* dimension that corresponds to the first
dimension of *mask*.

Returns
-------
np.ma.MaskedArray or da.Array:
A copy of the input array with the mask applied.

"""
if isinstance(array, da.Array):
array_chunks = array.chunks
# If the mask is not a Dask array yet, we make it into a Dask array
# before broadcasting to avoid inserting a large array into the Dask
# graph.
mask_chunks = tuple(array_chunks[i] for i in dim_map)
mask = da.asarray(mask, chunks=mask_chunks)
else:
array_chunks = None

mask = iris.util.broadcast_to_shape(
mask, array.shape, dim_map=dim_map, chunks=array_chunks
)

array_module = get_array_module(mask, array)
return array_module.ma.masked_where(mask, array)


def _rechunk_aux_factory_dependencies(
cube: iris.cube.Cube,
coord_name: str | None = None,
) -> iris.cube.Cube:
"""Rechunk coordinate aux factory dependencies.

This ensures that the resulting coordinate has reasonably sized
chunks that are aligned with the cube data for optimal computational
performance.
"""
# Workaround for https://github.com/SciTools/iris/issues/5457
if coord_name is None:
factories = cube.aux_factories
else:
try:
factories = [cube.aux_factory(coord_name)]
except iris.exceptions.CoordinateNotFoundError:
return cube

cube = cube.copy()
cube_chunks = cube.lazy_data().chunks
for factory in factories:
for coord in factory.dependencies.values():
coord_dims = cube.coord_dims(coord)
if coord_dims:
coord = coord.copy()
chunks = tuple(cube_chunks[i] for i in coord_dims)
coord.points = coord.lazy_points().rechunk(chunks)
if coord.has_bounds():
coord.bounds = coord.lazy_bounds().rechunk(
chunks + (None,)
)
cube.replace_coord(coord)
return cube
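
For context, a small sketch of what _rechunk_aux_factory_dependencies achieves, using a hypothetical hybrid-height cube (all names, shapes, and chunk sizes invented for illustration): the surface_altitude dependency starts with small chunks and ends up aligned with the cube's data chunks.

import dask.array as da
import numpy as np
import iris.cube
from iris.aux_factory import HybridHeightFactory
from iris.coords import AuxCoord, DimCoord

from esmvalcore.preprocessor._shared import _rechunk_aux_factory_dependencies

nlev, nlat, nlon = 4, 30, 40
level = DimCoord(np.arange(nlev), long_name="model_level_number")
delta = AuxCoord(da.arange(nlev, dtype=float, chunks=1), long_name="level_height", units="m")
sigma = AuxCoord(da.linspace(1.0, 0.0, nlev, chunks=1), long_name="sigma", units="1")
orog = AuxCoord(da.zeros((nlat, nlon), chunks=(10, 10)), long_name="surface_altitude", units="m")

# Cube with a single chunk per dimension and a hybrid height aux factory.
cube = iris.cube.Cube(
    da.zeros((nlev, nlat, nlon), chunks=(nlev, nlat, nlon)),
    dim_coords_and_dims=[(level, 0)],
)
cube.add_aux_coord(delta, 0)
cube.add_aux_coord(sigma, 0)
cube.add_aux_coord(orog, (1, 2))
cube.add_aux_factory(HybridHeightFactory(delta=delta, sigma=sigma, orography=orog))

print(cube.coord("surface_altitude").core_points().chunks)       # ((10, 10, 10), (10, 10, 10, 10))
rechunked = _rechunk_aux_factory_dependencies(cube)
print(rechunked.coord("surface_altitude").core_points().chunks)  # ((30,), (40,))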
4 changes: 4 additions & 0 deletions esmvalcore/preprocessor/_time.py
@@ -1267,6 +1267,10 @@ def timeseries_filter(
# Apply filter
(agg, agg_kwargs) = get_iris_aggregator(filter_stats, **operator_kwargs)
agg_kwargs["weights"] = wgts
if cube.has_lazy_data():
# Ensure the cube data chunktype is np.ma.MaskedArray so that
# rolling_window does not ignore a potential mask.
cube.data = da.ma.masked_array(cube.core_data())
cube = cube.rolling_window("time", agg, len(wgts), **agg_kwargs)

return cube
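As background for the change above, a tiny sketch (plain dask, no iris) of the chunktype promotion it relies on: wrapping a lazy array with da.ma.masked_array makes the computed result a numpy masked array instead of a plain ndarray, which is what lets downstream operations such as rolling_window keep track of a mask.

import dask.array as da

lazy = da.arange(10, chunks=5)
masked = da.ma.masked_array(lazy)

print(type(lazy.compute()))    # <class 'numpy.ndarray'>
print(type(masked.compute()))  # <class 'numpy.ma.core.MaskedArray'>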
25 changes: 0 additions & 25 deletions tests/unit/preprocessor/_mask/test_mask.py
@@ -9,7 +9,6 @@

import tests
from esmvalcore.preprocessor._mask import (
_apply_mask,
_get_fx_mask,
count_spells,
mask_above_threshold,
@@ -59,30 +58,6 @@ def setUp(self):
)
self.fx_data = np.array([20.0, 60.0, 50.0])

def test_apply_fx_mask_on_nonmasked_data(self):
"""Test _apply_fx_mask func."""
dummy_fx_mask = np.ma.array((True, False, True))
app_mask = _apply_mask(
dummy_fx_mask, self.time_cube.data[0:3].astype("float64")
)
fixed_mask = np.ma.array(
self.time_cube.data[0:3].astype("float64"), mask=dummy_fx_mask
)
self.assert_array_equal(fixed_mask, app_mask)

def test_apply_fx_mask_on_masked_data(self):
"""Test _apply_fx_mask func."""
dummy_fx_mask = np.ma.array((True, True, True))
masked_data = np.ma.array(
self.time_cube.data[0:3].astype("float64"),
mask=np.ma.array((False, True, False)),
)
app_mask = _apply_mask(dummy_fx_mask, masked_data)
fixed_mask = np.ma.array(
self.time_cube.data[0:3].astype("float64"), mask=dummy_fx_mask
)
self.assert_array_equal(fixed_mask, app_mask)

def test_count_spells(self):
"""Test count_spells func."""
ref_spells = count_spells(self.time_cube.data, -1000.0, 0, 1)