From 69c9ef87f6f364675a41ad2139eba7a56d02a261 Mon Sep 17 00:00:00 2001
From: Davis Bennett
Date: Tue, 23 Apr 2024 15:45:21 +0200
Subject: [PATCH] feat: fix type annotations, rename fuse_coordinate_transforms
 to normalize_transforms, and simplify transform logic in coordinate inference.

---
 pyproject.toml                        |  20 +++++
 src/xarray_ome_ngff/__init__.py       |   2 +-
 src/xarray_ome_ngff/array_wrap.py     |  51 +++++------
 src/xarray_ome_ngff/v04/multiscale.py | 125 ++++++++++++++------------
 tests/test_v04.py                     |   2 +-
 5 files changed, 112 insertions(+), 88 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0b787aa..e0c586b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,26 @@ pytest-examples = "^0.0.9"
 requests = "^2.31.0"
 aiohttp = "^3.9.3"
 
+[tool.mypy]
+plugins = [
+  "pydantic.mypy"
+]
+
+follow_imports = "silent"
+warn_redundant_casts = true
+warn_unused_ignores = true
+disallow_any_generics = true
+check_untyped_defs = true
+no_implicit_reexport = true
+
+# for strict mypy: (this is the tricky one :-))
+disallow_untyped_defs = true
+
+[tool.pydantic-mypy]
+init_forbid_extra = true
+init_typed = true
+warn_required_dynamic_aliases = true
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
diff --git a/src/xarray_ome_ngff/__init__.py b/src/xarray_ome_ngff/__init__.py
index 3d6ca38..10d642d 100644
--- a/src/xarray_ome_ngff/__init__.py
+++ b/src/xarray_ome_ngff/__init__.py
@@ -120,7 +120,7 @@ def create_multiscale_group(
     arrays: dict[str, DataArray],
     transform_precision: int | None = None,
     ngff_version: Literal["0.4"] = "0.4",
-):
+) -> MultiscaleGroupV04:
     """
     store: zarr.storage.BaseStore
         The storage backend for the Zarr hierarchy.
diff --git a/src/xarray_ome_ngff/array_wrap.py b/src/xarray_ome_ngff/array_wrap.py
index 5b68eb2..b8207c4 100644
--- a/src/xarray_ome_ngff/array_wrap.py
+++ b/src/xarray_ome_ngff/array_wrap.py
@@ -7,21 +7,21 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from importlib.util import find_spec
+from dask.array.core import Array as DaskArray
 import numpy as np
 import zarr
 
 
 @runtime_checkable
 class Arrayish(Protocol):
-    dtype: np.dtype
+    dtype: np.dtype[Any]
     shape: tuple[int, ...]
 
-    def __getitem__(self, *args) -> Self: ...
+    def __getitem__(self, *args: Any) -> Self: ...
 
 
 class ArrayWrapperSpec(TypedDict):
-    name: Literal["dask"]
+    name: str
     config: dict[str, Any]
 
 
@@ -31,18 +31,22 @@ class DaskArrayWrapperConfig(TypedDict):
     """
 
     chunks: str | int | tuple[int, ...] | tuple[tuple[int, ...], ...]
-    meta: Any = None
+    meta: Any
     inline_array: bool
 
 
 class ZarrArrayWrapperSpec(ArrayWrapperSpec):
-    name: Literal["zarr_array"]
-    config: dict[str, Any] = {}
+    # type checkers hate that the base class defines `name` to be mutable, but this is immutable.
+    # this will be fixed when Python allows declaring TypedDict fields as read-only.
+    name: Literal["zarr_array"]  # type: ignore
+    config: dict[str, Any]  # type: ignore
 
 
 class DaskArrayWrapperSpec(ArrayWrapperSpec):
-    name: Literal["dask_array"]
-    config: DaskArrayWrapperConfig
+    # type checkers hate that the base class defines `name` to be mutable, but this is immutable.
+    # this will be fixed when Python allows declaring TypedDict fields as read-only.
+    name: Literal["dask_array"]  # type: ignore
+    config: DaskArrayWrapperConfig  # type: ignore
 
 
 class BaseArrayWrapper(ABC):
@@ -56,7 +60,7 @@ class ZarrArrayWrapper(BaseArrayWrapper):
     An array wrapper that passes `zarr.Array` instances through unchanged.
""" - def wrap(self, data: zarr.Array) -> Arrayish: + def wrap(self, data: zarr.Array) -> zarr.Array: # type: ignore return data @@ -83,20 +87,7 @@ class DaskArrayWrapper(BaseArrayWrapper): meta: Any = None inline_array: bool = True - def __post_init__(self) -> None: - """ - Handle the lack of `dask`. - """ - if find_spec("dask") is None: - msg = ( - "Failed to import `dask` successfully. " - "The `dask` library is required to use `DaskArrayWrapper`." - "Install `dask` into your python environment, e.g. via " - "`pip install dask`, to resolve this issue." - ) - ImportError(msg) - - def wrap(self, data: zarr.Array): + def wrap(self, data: zarr.Array) -> DaskArray: # type: ignore """ Wrap the input in a dask array. """ @@ -107,27 +98,29 @@ def wrap(self, data: zarr.Array): ) -def resolve_wrapper(spec: ArrayWrapperSpec) -> BaseArrayWrapper: +def resolve_wrapper(spec: ArrayWrapperSpec) -> ZarrArrayWrapper | DaskArrayWrapper: """ Convert an `ArrayWrapperSpec` into the corresponding `BaseArrayWrapper` subclass. """ if spec["name"] == "dask_array": - spec = cast(DaskArrayWrapperConfig, spec) + spec = cast(DaskArrayWrapperSpec, spec) # type: ignore return DaskArrayWrapper(**spec["config"]) elif spec["name"] == "zarr_array": - spec = cast(ZarrArrayWrapperSpec, spec) + spec = cast(ZarrArrayWrapperSpec, spec) # type: ignore return ZarrArrayWrapper(**spec["config"]) else: raise ValueError(f"Spec {spec} is not recognized.") -def parse_wrapper(data: ArrayWrapperSpec | BaseArrayWrapper): +def parse_wrapper( + data: ArrayWrapperSpec | DaskArrayWrapper | ZarrArrayWrapper, +) -> DaskArrayWrapper | ZarrArrayWrapper: """ Parse the input into a `BaseArrayWrapper` subclass. If the input is already `BaseArrayWrapper`, it is returned as-is. Otherwise, the input is presumed to be `ArrayWrapperSpec` and is passed to `resolve_wrapper`. """ - if isinstance(data, BaseArrayWrapper): + if isinstance(data, (ZarrArrayWrapper, DaskArrayWrapper)): return data return resolve_wrapper(data) diff --git a/src/xarray_ome_ngff/v04/multiscale.py b/src/xarray_ome_ngff/v04/multiscale.py index 45d8c52..8771824 100644 --- a/src/xarray_ome_ngff/v04/multiscale.py +++ b/src/xarray_ome_ngff/v04/multiscale.py @@ -1,9 +1,8 @@ from __future__ import annotations -from pydantic_ome_ngff.v04.multiscale import Group +from pydantic_ome_ngff.v04.multiscale import Group as MultiscaleGroup from typing import TYPE_CHECKING -from xarray_ome_ngff.array_wrap import ArrayWrapperSpec -from xarray_ome_ngff.array_wrap import BaseArrayWrapper +from xarray_ome_ngff.array_wrap import ArrayWrapperSpec, DaskArrayWrapper from xarray_ome_ngff.array_wrap import ZarrArrayWrapper from xarray_ome_ngff.array_wrap import parse_wrapper @@ -106,7 +105,7 @@ def multiscale_metadata( def transforms_from_coords( coords: dict[str, DataArray], normalize_units: bool = True, - infer_axis_type=True, + infer_axis_type: bool = True, transform_precision: int | None = None, ) -> tuple[tuple[Axis, ...], tuple[VectorScale, VectorTranslation]]: """ @@ -143,9 +142,9 @@ def transforms_from_coords( Both transformations are are derived from the coordinates of the input array. """ - translate = () - scale = () - axes = () + translate: tuple[float, ...] = () + scale: tuple[float, ...] = () + axes: tuple[Axis, ...] 
 
     for dim, coord in coords.items():
         if (ndim := len(coord.dims)) != 1:
@@ -194,8 +193,8 @@
             ),
         )
     if transform_precision is not None:
-        scale = np.round(scale, transform_precision)
-        translate = np.round(translate, transform_precision)
+        scale = tuple(np.round(scale, transform_precision).tolist())
+        translate = tuple(np.round(translate, transform_precision).tolist())
 
     transforms = (
         VectorScale(scale=scale),
@@ -205,10 +204,11 @@
 
 
 def coords_from_transforms(
+    *,
     axes: Sequence[Axis],
-    transforms: Sequence[tuple[VectorScale] | tuple[VectorScale, VectorTranslation]],
+    transforms: tuple[VectorScale, VectorTranslation],
     shape: tuple[int, ...],
-) -> tuple[DataArray]:
+) -> tuple[DataArray, ...]:
     """
     Given an output shape, convert a sequence of Axis objects and a corresponding
     sequence of coordinateTransform objects into xarray-compatible coordinates.
@@ -223,44 +223,43 @@
 
     result = []
 
+    for tx in transforms:
+        if tx.type == "translation":
+            if len(tx.translation) != len(axes):
+                msg = (
+                    f"Translation parameter has length {len(tx.translation)}. "
+                    f"This does not match the number of axes {len(axes)}."
+                )
+                raise ValueError(msg)
+
+        elif tx.type == "scale":
+            if len(tx.scale) != len(axes):
+                msg = (
+                    f"Scale parameter has length {len(tx.scale)}. "
+                    f"This does not match the number of axes {len(axes)}."
+                )
+                raise ValueError(msg)
+        elif tx.type == "identity":
+            pass
+        else:
+            msg = (
+                f"Transform type {tx.type} not recognized. Must be one of scale, "
+                "translation, or identity"
+            )
+            raise ValueError(msg)
+
     for idx, axis in enumerate(axes):
         base_coord = np.arange(shape[idx], dtype="float")
         name = axis.name
         unit = axis.unit
         # apply transforms in order
         for tx in transforms:
-            if isinstance(getattr(tx, "path", None), str):
-                msg = (
-                    f"Problematic transform: {tx}. This library cannot handle "
-                    "transforms with paths. Resolve this path to a literal scale or "
-                    "translation"
-                )
-                raise ValueError(msg)
             if tx.type == "translation":
-                if len(tx.translation) != len(axes):
-                    msg = (
-                        f"Translation parameter has length {len(tx.translation)}. "
-                        f"This does not match the number of axes {len(axes)}."
-                    )
-                    raise ValueError(msg)
                 base_coord += tx.translation[idx]
             elif tx.type == "scale":
-                if len(tx.scale) != len(axes):
-                    msg = (
-                        f"Scale parameter has length {len(tx.scale)}. "
-                        f"This does not match the number of axes {len(axes)}."
-                    )
-                    raise ValueError(msg)
                 base_coord *= tx.scale[idx]
             elif tx.type == "identity":
                 pass
-            else:
-                msg = (
-                    f"Transform type {tx.type} not recognized. Must be one of scale, "
-                    "translation, or identity"
-                )
-                raise ValueError(msg)
 
         result.append(
             DataArray(
@@ -273,14 +272,14 @@
     return tuple(result)
 
 
-def fuse_coordinate_transforms(
+def normalize_transforms(
     base_transforms: (
         None | tuple[()] | tuple[VectorScale] | tuple[VectorScale, VectorTranslation]
     ),
     dset_transforms: tuple[VectorScale] | tuple[VectorScale, VectorTranslation],
 ) -> tuple[VectorScale, VectorTranslation]:
 
-    if base_transforms is None or base_transforms == ():
+    if base_transforms is None or len(base_transforms) == 0:
         out_scale = dset_transforms[0]
         if len(dset_transforms) == 1:
             out_trans = VectorTranslation(translation=(0,) * len(out_scale.scale))
@@ -299,10 +298,16 @@
             out_trans = dset_transforms[1]
     else:
         base_trans = base_transforms[1]
-        dset_trans = dset_transforms[1]
+        if len(dset_transforms) == 2:
+            dset_trans = dset_transforms[1]
+        else:
+            dset_trans = VectorTranslation(
+                translation=(0,) * len(base_trans.translation)
+            )
         out_trans = VectorTranslation(
             translation=tuple(
-                b + d for b, d in zip(base_trans.scale, dset_trans.scale)
+                b + d
+                for b, d in zip(base_trans.translation, dset_trans.translation)
             )
         )
 
@@ -311,7 +316,7 @@
 def model_group(
     *, arrays: dict[str, DataArray], transform_precision: int | None = None
-) -> Group:
+) -> MultiscaleGroup:
     """
     Create a model of an OME-NGFF multiscale group from a dict of `xarray.DataArray`.
     The dimensions / coordinates of the arrays will be used to infer OME-NGFF axis metadata, as well
@@ -346,9 +351,9 @@
         )
         raise ValueError(msg)
 
-    group = Group.from_arrays(
-        arrays=arrays.values(),
-        paths=arrays.keys(),
+    group = MultiscaleGroup.from_arrays(
+        arrays=tuple(arrays.values()),
+        paths=tuple(arrays.keys()),
         axes=axeses[0],
         scales=[tx[0].scale for tx in transformses],
         translations=[tx[1].translation for tx in transformses],
@@ -362,7 +367,7 @@ def create_group(
     path: str,
     arrays: dict[str, DataArray],
     transform_precision: int | None = None,
-):
+) -> MultiscaleGroup:
     """
     Create a Zarr group that complies with version 0.4 of the OME-NGFF multiscale
     specification from a dict of `xarray.DataArray`.
@@ -388,7 +393,9 @@ def create_group(
 def read_group(
     group: zarr.Group,
     *,
-    array_wrapper: BaseArrayWrapper | ArrayWrapperSpec = ZarrArrayWrapper(),
+    array_wrapper: (
+        ZarrArrayWrapper | DaskArrayWrapper | ArrayWrapperSpec
+    ) = ZarrArrayWrapper(),
     multiscales_index: int = 0,
 ) -> dict[str, DataArray]:
     """
@@ -421,23 +428,25 @@
     """
     result: dict[str, DataArray] = {}
     # parse the zarr group as a multiscale group
-    multiscale_group_parsed = Group.from_zarr(group)
+    multiscale_group_parsed = MultiscaleGroup.from_zarr(group)
     multi_meta = multiscale_group_parsed.attributes.multiscales[multiscales_index]
     multi_tx = multi_meta.coordinateTransformations
     wrapper_parsed = parse_wrapper(array_wrapper)
 
     for dset in multi_meta.datasets:
-        tx_fused = fuse_coordinate_transforms(multi_tx, dset.coordinateTransformations)
+        tx_fused = normalize_transforms(multi_tx, dset.coordinateTransformations)
         arr_z: zarr.Array = group[dset.path]
         arr_wrapped = wrapper_parsed.wrap(arr_z)
-        coords = coords_from_transforms(multi_meta.axes, tx_fused, arr_z.shape)
+        coords = coords_from_transforms(
+            axes=multi_meta.axes, transforms=tx_fused, shape=arr_z.shape
+        )
         arr_out = DataArray(data=arr_wrapped, coords=coords)
         result[dset.path] = arr_out
 
     return result
 
 
-def get_parent(node: zarr.Group | zarr.Array):
+def get_parent(node: zarr.Group | zarr.Array) -> zarr.Group:
     """
     Get the parent node of a Zarr array or group
     """
@@ -450,13 +459,15 @@
         new_full_path, new_node_path = os.path.split(full_path)
         return zarr.open_group(type(node.store)(new_full_path), path=new_node_path)
     else:
-        return zarr.open(store=node.store, path=os.path.split(node.path)[0])
+        return zarr.open_group(store=node.store, path=os.path.split(node.path)[0])
 
 
 def read_array(
     array: zarr.Array,
     *,
-    array_wrapper: BaseArrayWrapper | ArrayWrapperSpec = ZarrArrayWrapper(),
+    array_wrapper: (
+        ZarrArrayWrapper | DaskArrayWrapper | ArrayWrapperSpec
+    ) = ZarrArrayWrapper(),
 ) -> DataArray:
     """
     Read a single Zarr array as an `xarray.DataArray`, using version 0.4 OME-NGFF multiscale
@@ -473,7 +484,7 @@
     ----------
     array: zarr.Array
         A Zarr array that is part of a version 0.4 OME-NGFF multiscale image.
-    array_wrapper: BaseArrayWrapper | ArrayWrapperSpec, default is ZarrArrayWrapper
+    array_wrapper: ZarrArrayWrapper | DaskArrayWrapper | ArrayWrapperSpec, default is ZarrArrayWrapper
         The array wrapper class to use when converting the Zarr array to an `xarray.DataArray`.
Returns ------- @@ -501,17 +512,17 @@ def read_array( parent_path = parent.path array_path_rel = os.path.relpath(full_path, parent_path) try: - model = Group.from_zarr(parent) + model = MultiscaleGroup.from_zarr(parent) for multi in model.attributes.multiscales: multi_tx = multi.coordinateTransformations for dset in multi.datasets: if dset.path == array_path_rel: - tx_fused = fuse_coordinate_transforms( + tx_fused = normalize_transforms( multi_tx, dset.coordinateTransformations ) coords = coords_from_transforms( - multi.axes, tx_fused, array.shape + axes=multi.axes, transforms=tx_fused, shape=array.shape ) arr_wrapped = wrapper.wrap(array) return DataArray(arr_wrapped, coords=coords) diff --git a/tests/test_v04.py b/tests/test_v04.py index 272b75b..8b55a34 100644 --- a/tests/test_v04.py +++ b/tests/test_v04.py @@ -193,7 +193,7 @@ def test_create_coords(): VectorTranslation(translation=(1, 2)), ] - coords = coords_from_transforms(axes, transforms, shape) + coords = coords_from_transforms(axes=axes, transforms=transforms, shape=shape) assert coords[0].equals( DataArray( np.array([1.0, 2.0, 3.0]),
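
Usage sketch (not part of the patch): how the renamed `normalize_transforms`
and the now keyword-only `coords_from_transforms` compose after this change.
The axis metadata and transform values are invented for illustration, and the
`pydantic_ome_ngff` import paths are assumptions.

    from pydantic_ome_ngff.v04.axis import Axis
    from pydantic_ome_ngff.v04.transform import VectorScale, VectorTranslation

    from xarray_ome_ngff.v04.multiscale import (
        coords_from_transforms,
        normalize_transforms,
    )

    # Fold the (optional) multiscale-level transforms into the dataset-level
    # pair; with no group-level transforms the dataset transforms pass through,
    # and a zero translation is synthesized if one is missing.
    scale, translation = normalize_transforms(
        None,
        (VectorScale(scale=(2.0, 2.0)), VectorTranslation(translation=(0.5, 0.5))),
    )

    # Rebuild xarray coordinates for a 10 x 10 array. Arguments are keyword-only
    # after this patch; each coordinate is arange(shape) scaled, then shifted.
    coords = coords_from_transforms(
        axes=(
            Axis(name="y", type="space", unit="meter"),
            Axis(name="x", type="space", unit="meter"),
        ),
        transforms=(scale, translation),
        shape=(10, 10),
    )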
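
A second sketch, reading a multiscale group lazily through the reworked wrapper
unions. The store path "example.zarr" is hypothetical; `DaskArrayWrapper` takes
the dataclass fields shown in the diff (chunks, meta, inline_array).

    import zarr

    from xarray_ome_ngff.array_wrap import DaskArrayWrapper
    from xarray_ome_ngff.v04.multiscale import read_group

    group = zarr.open_group("example.zarr", mode="r")

    # Each dataset path maps to an xarray.DataArray backed by a dask array.
    # The default ZarrArrayWrapper() would instead pass each zarr.Array
    # through unchanged.
    arrays = read_group(group, array_wrapper=DaskArrayWrapper(chunks="auto"))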