Skip to content

Commit

Permalink
feat: fix type annotations, rename fuse_coordinate_transforms to norm…
Browse files Browse the repository at this point in the history
…alize_transforms and simplify transform logic in coordinate inference.
  • Loading branch information
d-v-b committed Apr 23, 2024
1 parent 5610216 commit 69c9ef8
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 88 deletions.
20 changes: 20 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ pytest-examples = "^0.0.9"
requests = "^2.31.0"
aiohttp = "^3.9.3"

[tool.mypy]
plugins = [
"pydantic.mypy"
]

follow_imports = "silent"
warn_redundant_casts = true
warn_unused_ignores = true
disallow_any_generics = true
check_untyped_defs = true
no_implicit_reexport = true

# for strict mypy: (this is the tricky one :-))
disallow_untyped_defs = true

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
2 changes: 1 addition & 1 deletion src/xarray_ome_ngff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def create_multiscale_group(
arrays: dict[str, DataArray],
transform_precision: int | None = None,
ngff_version: Literal["0.4"] = "0.4",
):
) -> MultiscaleGroupV04:
"""
store: zarr.storage.BaseStore
The storage backend for the Zarr hierarchy.
Expand Down
51 changes: 22 additions & 29 deletions src/xarray_ome_ngff/array_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from importlib.util import find_spec
from dask.array.core import Array as DaskArray
import numpy as np
import zarr


@runtime_checkable
class Arrayish(Protocol):
dtype: np.dtype
dtype: np.dtype[Any]
shape: tuple[int, ...]

def __getitem__(self, *args) -> Self: ...
def __getitem__(self, *args: Any) -> Self: ...


class ArrayWrapperSpec(TypedDict):
name: Literal["dask"]
name: str
config: dict[str, Any]


Expand All @@ -31,18 +31,22 @@ class DaskArrayWrapperConfig(TypedDict):
"""

chunks: str | int | tuple[int, ...] | tuple[tuple[int, ...], ...]
meta: Any = None
meta: Any
inline_array: bool


class ZarrArrayWrapperSpec(ArrayWrapperSpec):
name: Literal["zarr_array"]
config: dict[str, Any] = {}
# type checkers hate that the base class defines `name` to be mutable, but this is immutable.
# this will be fixed with python allows declaring typeddict fields as read-only.
name: Literal["zarr_array"] # type: ignore
config: dict[str, Any] # type: ignore


class DaskArrayWrapperSpec(ArrayWrapperSpec):
name: Literal["dask_array"]
config: DaskArrayWrapperConfig
# type checkers hate that the base class defines `name` to be mutable, but this is immutable.
# this will be fixed with python allows declaring typeddict fields as read-only.
name: Literal["dask_array"] # type: ignore
config: DaskArrayWrapperConfig # type: ignore


class BaseArrayWrapper(ABC):
Expand All @@ -56,7 +60,7 @@ class ZarrArrayWrapper(BaseArrayWrapper):
An array wrapper that passes a `zarr.Array` instances through unchanged.
"""

def wrap(self, data: zarr.Array) -> Arrayish:
def wrap(self, data: zarr.Array) -> zarr.Array: # type: ignore
return data


Expand All @@ -83,20 +87,7 @@ class DaskArrayWrapper(BaseArrayWrapper):
meta: Any = None
inline_array: bool = True

def __post_init__(self) -> None:
"""
Handle the lack of `dask`.
"""
if find_spec("dask") is None:
msg = (
"Failed to import `dask` successfully. "
"The `dask` library is required to use `DaskArrayWrapper`."
"Install `dask` into your python environment, e.g. via "
"`pip install dask`, to resolve this issue."
)
ImportError(msg)

def wrap(self, data: zarr.Array):
def wrap(self, data: zarr.Array) -> DaskArray: # type: ignore
"""
Wrap the input in a dask array.
"""
Expand All @@ -107,27 +98,29 @@ def wrap(self, data: zarr.Array):
)


def resolve_wrapper(spec: ArrayWrapperSpec) -> BaseArrayWrapper:
def resolve_wrapper(spec: ArrayWrapperSpec) -> ZarrArrayWrapper | DaskArrayWrapper:
"""
Convert an `ArrayWrapperSpec` into the corresponding `BaseArrayWrapper` subclass.
"""
if spec["name"] == "dask_array":
spec = cast(DaskArrayWrapperConfig, spec)
spec = cast(DaskArrayWrapperSpec, spec) # type: ignore
return DaskArrayWrapper(**spec["config"])
elif spec["name"] == "zarr_array":
spec = cast(ZarrArrayWrapperSpec, spec)
spec = cast(ZarrArrayWrapperSpec, spec) # type: ignore
return ZarrArrayWrapper(**spec["config"])
else:
raise ValueError(f"Spec {spec} is not recognized.")


def parse_wrapper(data: ArrayWrapperSpec | BaseArrayWrapper):
def parse_wrapper(
data: ArrayWrapperSpec | DaskArrayWrapper | ZarrArrayWrapper,
) -> DaskArrayWrapper | ZarrArrayWrapper:
"""
Parse the input into a `BaseArrayWrapper` subclass.
If the input is already `BaseArrayWrapper`, it is returned as-is.
Otherwise, the input is presumed to be `ArrayWrapperSpec` and is passed to `resolve_wrapper`.
"""
if isinstance(data, BaseArrayWrapper):
if isinstance(data, (ZarrArrayWrapper, DaskArrayWrapper)):
return data
return resolve_wrapper(data)
Loading

0 comments on commit 69c9ef8

Please sign in to comment.