Commit

Merge pull request #12 from JaneliaSciComp/feat/support-overwrite-create-group

support `overwrite` keyword argument when creating a multiscale group
d-v-b authored Aug 29, 2024
2 parents 77610a0 + b11e4c0 commit f496a07
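
As a quick sketch of the new behavior (illustrative only: the array shapes, `meter` units, in-memory store, and `make_level` helper below are assumptions, not part of this commit; only `create_multiscale_group` and its `overwrite` argument come from the diff):

```python
import numpy as np
import zarr
from xarray import DataArray

from xarray_ome_ngff.v04.multiscale import create_multiscale_group


def make_level(shape: tuple[int, int], scale: float) -> DataArray:
    # Hypothetical helper: a 2D array whose coordinates encode the pixel scale.
    coords = {
        dim: DataArray(np.arange(n) * scale, dims=dim, attrs={"units": "meter"})
        for dim, n in zip(("y", "x"), shape)
    }
    return DataArray(np.zeros(shape, dtype="uint8"), dims=("y", "x"), coords=coords)


store = zarr.storage.MemoryStore()
arrays = {"s0": make_level((64, 64), 2.0), "s1": make_level((32, 32), 4.0)}

# The first call creates the group. Before this change, the second call would
# raise because a group already exists at `path`; with overwrite=True the
# existing group is replaced instead.
create_multiscale_group(store=store, path="example", arrays=arrays)
create_multiscale_group(store=store, path="example", arrays=arrays, overwrite=True)
```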
Showing 6 changed files with 139 additions and 56 deletions.
13 changes: 4 additions & 9 deletions .pre-commit-config.yaml
@@ -7,27 +7,22 @@ default_language_version:
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.3
rev: v0.6.3
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-yaml
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.11.2
hooks:
- id: mypy
files: zarr
args: []
additional_dependencies:
- types-redis
- types-setuptools
4 changes: 0 additions & 4 deletions docs/index.md
@@ -173,9 +173,6 @@ print(group.attrs.asdict())
'multiscales': (
{
'version': '0.4',
'name': None,
'type': None,
'metadata': None,
'datasets': (
{
'path': 's0',
@@ -196,7 +193,6 @@ print(group.attrs.asdict())
{'name': 'x', 'type': 'space', 'unit': 'meter'},
{'name': 'y', 'type': 'space', 'unit': 'meter'},
),
'coordinateTransformations': None,
},
)
}
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
dependencies = [
"zarr<3.0.0",
"xarray >= 2023.2.0",
"pydantic-ome-ngff == 0.5.3",
"pydantic-ome-ngff == 0.6.0",
"pint >= 0.24",
"pydantic >= 2.0.0, <3",
"dask >= 2022.3.0"
@@ -45,6 +45,7 @@ dependencies = [
"pytest-examples >= 0.0.9",
"requests",
"aiohttp",
"dask"
]

[[tool.hatch.envs.test.matrix]]
62 changes: 51 additions & 11 deletions src/xarray_ome_ngff/__init__.py
@@ -1,10 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from importlib.metadata import version as _version
from zarr.storage import BaseStore
from xarray_ome_ngff.array_wrap import (
ArrayWrapperSpec,
BaseArrayWrapper,
DaskArrayWrapper,
ZarrArrayWrapper,
)
@@ -13,7 +12,7 @@
if TYPE_CHECKING:
from typing import Literal
from pydantic_ome_ngff.v04.multiscale import Group as MultiscaleGroupV04

from numcodecs.abc import Codec
from xarray import DataArray
import zarr
from xarray_ome_ngff.v04 import multiscale as multiscale_v04
@@ -24,7 +23,9 @@
# todo: remove the need to specify the version for reading
def read_multiscale_group(
group: zarr.Group,
array_wrapper: BaseArrayWrapper | ArrayWrapperSpec = ZarrArrayWrapper(),
array_wrapper: ZarrArrayWrapper
| DaskArrayWrapper
| ArrayWrapperSpec = ZarrArrayWrapper(),
ngff_version: Literal["0.4"] = "0.4",
**kwargs,
) -> dict[str, DataArray]:
@@ -39,7 +40,7 @@ def read_multiscale_group(
----------
group: zarr.Group
A handle for the Zarr group that contains the OME-NGFF metadata.
array_wrapper: BaseArrayWrapper | ArrayWrapperSpec, default is ZarrArrayWrapper
array_wrapper: ZarrArrayWrapper | DaskArrayWrapper | ArrayWrapperSpec, default = ZarrArrayWrapper
Either an object that implements `BaseArrayWrapper`, or a dict model of such a subclass,
which will be resolved to an object implementing `BaseArrayWrapper`. This object has a
`wrap` method that takes an instance of `zarr.Array` and returns another array-like object.
@@ -52,13 +53,17 @@
if ngff_version not in NGFF_VERSIONS:
raise ValueError(f"Unsupported NGFF version: {ngff_version}")
if ngff_version == "0.4":
return multiscale_v04.read_group(group, array_wrapper=array_wrapper, **kwargs)
return multiscale_v04.read_multiscale_group(
group, array_wrapper=array_wrapper, **kwargs
)


# todo: remove the need to specify the version for reading
def read_multiscale_array(
array: zarr.Array,
array_wrapper: BaseArrayWrapper | ArrayWrapperSpec = ZarrArrayWrapper(),
array_wrapper: ZarrArrayWrapper
| DaskArrayWrapper
| ArrayWrapperSpec = ZarrArrayWrapper(),
ngff_version: Literal["0.4"] = "0.4",
**kwargs,
) -> DataArray:
@@ -80,14 +85,19 @@ def read_multiscale_array(
if ngff_version not in NGFF_VERSIONS:
raise ValueError(f"Unsupported NGFF version: {ngff_version}")
if ngff_version == "0.4":
return multiscale_v04.read_array(array, array_wrapper=array_wrapper, **kwargs)
return multiscale_v04.read_multiscale_array(
array, array_wrapper=array_wrapper, **kwargs
)


def model_multiscale_group(
*,
arrays: dict[str, DataArray],
transform_precision: int | None = None,
ngff_version: Literal["0.4"] = "0.4",
chunks: tuple[int, ...] | tuple[tuple[int, ...]] | Literal["auto"] = "auto",
compressor: Codec | None = multiscale_v04.DEFAULT_COMPRESSOR,
fill_value: Any = 0,
) -> MultiscaleGroupV04:
"""
Create a model of an OME-NGFF multiscale group from a dict of `xarray.DataArray`.
@@ -104,12 +114,26 @@ def model_multiscale_group(
x decimal places using `numpy.round(transform, x)`.
ngff_version: Literal["0.4"]
The OME-NGFF version to use.
chunks: tuple[int] | tuple[tuple[int, ...]] | Literal["auto"], default = "auto"
The chunks for the arrays in the multiscale group.
If the string "auto" is provided, each array will have chunks set to the zarr-python default
value, which depends on the shape and dtype of the array.
If a single sequence of ints is provided, then this defines the chunks for all arrays.
If a sequence of sequences of ints is provided, then this defines the chunks for each array.
compressor: Codec | None, default = numcodecs.ZStd.
The compressor to use for the arrays. Default is `numcodecs.ZStd`.
fill_value: Any
The fill value for the Zarr arrays.
"""
if ngff_version not in NGFF_VERSIONS:
raise ValueError(f"Unsupported NGFF version: {ngff_version}")
if ngff_version == "0.4":
return multiscale_v04.model_group(
arrays=arrays, transform_precision=transform_precision
return multiscale_v04.model_multiscale_group(
arrays=arrays,
transform_precision=transform_precision,
compressor=compressor,
chunks=chunks,
fill_value=fill_value,
)


@@ -120,6 +144,9 @@ def create_multiscale_group(
arrays: dict[str, DataArray],
transform_precision: int | None = None,
ngff_version: Literal["0.4"] = "0.4",
chunks: tuple[int, ...] | tuple[tuple[int, ...]] | Literal["auto"] = "auto",
compressor: Codec | None = multiscale_v04.DEFAULT_COMPRESSOR,
fill_value: Any = 0,
) -> MultiscaleGroupV04:
"""
store: zarr.storage.BaseStore
@@ -134,15 +161,28 @@
x decimal places using `numpy.round(transform, x)`.
ngff_version: Literal["0.4"]
The OME-NGFF version to use.
chunks: tuple[int] | tuple[tuple[int, ...]] | Literal["auto"], default = "auto"
The chunks for the arrays in the multiscale group.
If the string "auto" is provided, each array will have chunks set to the zarr-python default
value, which depends on the shape and dtype of the array.
If a single sequence of ints is provided, then this defines the chunks for all arrays.
If a sequence of sequences of ints is provided, then this defines the chunks for each array.
compressor: Codec | None, default = numcodecs.ZStd.
The compressor to use for the arrays. Default is `numcodecs.ZStd`.
fill_value: Any
The fill value for the Zarr arrays.
"""
if ngff_version not in NGFF_VERSIONS:
raise ValueError(f"Unsupported NGFF version: {ngff_version}")
if ngff_version == "0.4":
return multiscale_v04.create_group(
return multiscale_v04.create_multiscale_group(
store=store,
path=path,
arrays=arrays,
transform_precision=transform_precision,
compressor=compressor,
chunks=chunks,
fill_value=fill_value,
)


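
For completeness, a sketch of reading data back with the renamed readers (this continues the `store` and group from the sketch near the top of this page; `ZarrArrayWrapper` is the default wrapper in the signatures above):

```python
import zarr

from xarray_ome_ngff import read_multiscale_array, read_multiscale_group
from xarray_ome_ngff.array_wrap import ZarrArrayWrapper

# Re-open the group written earlier and read it back as a dict of
# scale-level name -> xarray.DataArray.
group = zarr.open_group(store=store, path="example", mode="r")
tree = read_multiscale_group(group, array_wrapper=ZarrArrayWrapper())

# A single scale level can also be read directly from its Zarr array.
s0 = read_multiscale_array(group["s0"])
```

Per the updated annotations, a `DaskArrayWrapper` or an `ArrayWrapperSpec` dict is also accepted in place of `ZarrArrayWrapper` for lazy loading; its constructor arguments are not shown in this diff, so it is left out of the sketch.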
51 changes: 38 additions & 13 deletions src/xarray_ome_ngff/v04/multiscale.py
@@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Literal

from pydantic_ome_ngff.v04.multiscale import Group as MultiscaleGroup
from pydantic_ome_ngff.v04 import MultiscaleGroup

from xarray_ome_ngff.array_wrap import (
ArrayWrapperSpec,
@@ -33,6 +33,8 @@

from xarray_ome_ngff.core import CoordinateAttrs, ureg

DEFAULT_COMPRESSOR = Zstd(3)


def multiscale_metadata(
arrays: dict[str, DataArray],
@@ -322,12 +324,12 @@ def normalize_transforms(
return out_scale, out_trans


def model_group(
def model_multiscale_group(
*,
arrays: dict[str, DataArray],
transform_precision: int | None = None,
chunks: tuple[int, ...] | tuple[tuple[int, ...]] | Literal["auto"] = "auto",
compressor: Codec | None = Zstd(3),
compressor: Codec | None = DEFAULT_COMPRESSOR,
fill_value: Any = 0,
) -> MultiscaleGroup:
"""
@@ -341,8 +343,18 @@ def model_group(
A mapping from strings to `xarray.DataArray`.
transform_precision: int | None, default is None
Whether, and how much, to round the transformations estimated from the coordinates.
The default (`None`) results in no rounding; specifying an `int` x will round transforms to
x decimal places using `numpy.round(transform, x)`.
The default (`None`) results in no rounding; if `transform_precision` is an int, then
transforms will be rounded to `transform_precision` decimal places using `numpy.round`.
chunks: tuple[int] | tuple[tuple[int, ...]] | Literal["auto"], default = "auto"
The chunks for the arrays in the multiscale group.
If the string "auto" is provided, each array will have chunks set to the zarr-python default
value, which depends on the shape and dtype of the array.
If a single sequence of ints is provided, then this defines the chunks for all arrays.
If a sequence of sequences of ints is provided, then this defines the chunks for each array.
compressor: Codec | None, default = numcodecs.ZStd.
The compressor to use for the arrays. Default is `numcodecs.ZStd`.
fill_value: Any
The fill value for the Zarr arrays.
"""

# pluralses of plurals
@@ -377,15 +389,16 @@ def model_group(
return group


def create_group(
def create_multiscale_group(
*,
store: BaseStore,
path: str,
arrays: dict[str, DataArray],
transform_precision: int | None = None,
chunks: tuple[int, ...] | tuple[tuple[int, ...]] | Literal["auto"] = "auto",
compressor: Codec | None = Zstd(3),
compressor: Codec | None = DEFAULT_COMPRESSOR,
fill_value: Any = 0,
overwrite: bool = False,
) -> zarr.Group:
"""
Create Zarr group that complies with 0.4 of the OME-NGFF multiscale specification from a dict
@@ -402,20 +415,32 @@
Whether, and how much, to round the transformations estimated from the coordinates.
The default (`None`) results in no rounding; specifying an `int` x will round transforms to
x decimal places using `numpy.round(transform, x)`.
chunks: tuple[int] | tuple[tuple[int, ...]] | Literal["auto"], default = "auto"
The chunks for the arrays in the multiscale group.
If the string "auto" is provided, each array will have chunks set to the zarr-python default
value, which depends on the shape and dtype of the array.
If a single sequence of ints is provided, then this defines the chunks for all arrays.
If a sequence of sequences of ints is provided, then this defines the chunks for each array.
compressor: Codec | None, default = numcodecs.ZStd.
The compressor to use for the arrays. Default is `numcodecs.ZStd`.
fill_value: Any
The fill value for the Zarr arrays.
overwrite: bool, default = False
Whether to overwrite an existing Zarr array or group at `path`. Default is False, which will
result in an exception being raised if a Zarr array or group already exists at `path`.
"""

model = model_group(
model = model_multiscale_group(
arrays=arrays,
transform_precision=transform_precision,
chunks=chunks,
compressor=compressor,
fill_value=fill_value,
)
return model.to_zarr(store, path)
return model.to_zarr(store, path, overwrite=overwrite)


def read_group(
def read_multiscale_group(
group: zarr.Group,
*,
array_wrapper: (
@@ -471,7 +496,7 @@ def read_group(
return result


def read_array(
def read_multiscale_array(
array: zarr.Array,
*,
array_wrapper: (
@@ -538,6 +563,6 @@ def read_array(
except KeyError:
parent = get_parent(parent)
raise FileNotFoundError(
f"Could not find version 0.4 OME-NGFF multiscale metadata in any Zarr groups in the "
"Could not find version 0.4 OME-NGFF multiscale metadata in any Zarr groups"
f"ancestral to the array at {array.path}"
)
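
And a sketch of the new chunking and compression knobs on `model_multiscale_group` (the per-level chunk shapes and Zstd level are arbitrary, and pairing each chunk spec with a scale level is assumed to follow the order of the `arrays` dict from the first sketch):

```python
from numcodecs import Zstd

from xarray_ome_ngff.v04.multiscale import model_multiscale_group

# Build an in-memory model of the multiscale group without writing to storage.
model = model_multiscale_group(
    arrays=arrays,                # the dict from the first sketch
    chunks=((32, 32), (16, 16)),  # one chunk shape per scale level
    compressor=Zstd(level=5),     # overrides DEFAULT_COMPRESSOR = Zstd(3)
    fill_value=0,
)
```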