Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate axes types v0.4 #124

Merged
merged 17 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion ome_zarr/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,15 @@ def version(self) -> str:
return "0.3"


CurrentFormat = FormatV03
class FormatV04(FormatV03):
    """
    Format version 0.4 (Nov 2021).

    Changes relative to 0.3: axes is a list of dicts, and
    transformations are introduced in multiscales.
    """

    @property
    def version(self) -> str:
        """Version string identifying this format."""
        return "0.4"


CurrentFormat = FormatV04
13 changes: 9 additions & 4 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def __init__(self, node: Node) -> None:
"version", "0.1"
) # should this be matched with Format.version?
datasets = multiscales[0]["datasets"]
# axes field was introduced in 0.3, before all data was 5d
axes = tuple(multiscales[0].get("axes", ["t", "c", "z", "y", "x"]))
if len(set(axes) - axes_values) > 0:
axes = multiscales[0].get("axes")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the main implication of not setting a default value here?
For 0.1/0.2 data, this means that node.metadata["axes"] might be None, as opposed to ["t", "c", "z", "y", "x"] previously, i.e. we are preserving the value stored in the metadata? Is there an impact on clients relying on node.metadata["axes"]?

if version == "0.3" and len(set(axes) - axes_values) > 0:
# TODO: validate axis for > V0.3 ?
will-moore marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError(f"Invalid axes names: {set(axes) - axes_values}")
node.metadata["axes"] = axes
datasets = [d["path"] for d in datasets]
Expand All @@ -301,7 +301,12 @@ def __init__(self, node: Node) -> None:
for c in data.chunks
]
LOGGER.info("resolution: %s", resolution)
LOGGER.info(" - shape %s = %s", axes, data.shape)
axes_names = None
if axes is not None:
axes_names = tuple(
axis if isinstance(axis, str) else axis["name"] for axis in axes
)
LOGGER.info(" - shape %s = %s", axes_names, data.shape)
LOGGER.info(" - chunks = %s", chunk_sizes)
LOGGER.info(" - dtype = %s", data.dtype)
node.data.append(data)
Expand Down
120 changes: 96 additions & 24 deletions ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

"""
import logging
from typing import Any, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import zarr
Expand All @@ -13,18 +13,75 @@

LOGGER = logging.getLogger("ome_zarr.writer")

# Maps each recognised axis name to its axis type. _axes_to_dicts uses it to
# fill in "type" for axes given as plain strings, and _validate_axes_types
# treats its values as the set of known types.
KNOWN_AXES = {"x": "space", "y": "space", "z": "space", "c": "channel", "t": "time"}

def _validate_axes_names(
ndim: int, axes: Union[str, List[str]] = None, fmt: Format = CurrentFormat()
) -> Union[None, List[str]]:
"""Returns validated list of axes names or raise exception if invalid"""

def _axes_to_dicts(
    axes: Union[List[str], List[Dict[str, str]]]
) -> List[Dict[str, str]]:
    """Normalise axes to a list of axis dicts.

    A string entry becomes ``{"name": <entry>}``, with a ``"type"`` key added
    when the name is listed in ``KNOWN_AXES``. Any non-string entry is passed
    through as-is (the same object, not a copy).
    """

    def _as_dict(entry: Union[str, Dict[str, str]]) -> Dict[str, str]:
        # Entries that are not plain names are assumed to be axis dicts already.
        if not isinstance(entry, str):
            return entry
        if entry in KNOWN_AXES:
            return {"name": entry, "type": KNOWN_AXES[entry]}
        return {"name": entry}

    return [_as_dict(entry) for entry in axes]


def _axes_to_names(axes: List[Dict[str, str]]) -> List[str]:
"""Returns a list of axis names"""
axes_names = []
for axis in axes:
if "name" not in axis:
raise ValueError("Axis Dict %s has no 'name'" % axis)
axes_names.append(axis["name"])
return axes_names


def _validate_axes_types(axes_dicts: List[Dict[str, str]]) -> None:
    """
    Validate the axes types according to the spec, version 0.4+

    Checks, in order:
      * at most one axis has a type that is not a value of ``KNOWN_AXES``
        (a missing "type" counts as unknown);
      * a 'time' axis may only be the first axis;
      * at most one axis is of type 'channel';
      * a 'channel' axis must be followed by the 'space' axes.

    Raises:
        ValueError: if any of the checks above fails.
    """
    axes_types = [axis.get("type") for axis in axes_dicts]
    known_types = list(KNOWN_AXES.values())
    unknown_types = [atype for atype in axes_types if atype not in known_types]
    if len(unknown_types) > 1:
        raise ValueError(
            f"Too many unknown axes types. 1 allowed, found: {unknown_types}"
        )

    # A 'time' entry anywhere after position 0 is an error; this also
    # rejects more than one 'time' axis.
    if "time" in axes_types[1:]:
        raise ValueError("'time' axis must be first dimension only")

    if axes_types.count("channel") > 1:
        raise ValueError("Only 1 axis can be type 'channel'")

    if "channel" in axes_types:
        # At most one 'channel' remains at this point, so index() finds it.
        channel_index = axes_types.index("channel")
        # Test for 'space' before calling index(): without a 'space' axis,
        # list.index would raise an incidental "'space' is not in list"
        # ValueError. The input is still rejected, with the explicit message.
        if "space" not in axes_types or channel_index > axes_types.index("space"):
            raise ValueError("'space' axes must come after 'channel'")


def _validate_axes(
ndim: int = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
fmt: Format = CurrentFormat(),
) -> Union[None, List[str], List[Dict[str, str]]]:
"""Returns list of axes valid for fmt.version or raise exception if invalid"""

if fmt.version in ("0.1", "0.2"):
if axes is not None:
LOGGER.info("axes ignored for version 0.1 or 0.2")
return None

# handle version 0.3...
# We can guess axes for 2D and 5D data
if axes is None:
if ndim == 2:
axes = ["y", "x"]
Expand All @@ -37,16 +94,30 @@ def _validate_axes_names(
"axes must be provided. Can't be guessed for 3D or 4D data"
)

# axes may be string e.g. "tczyx"
if isinstance(axes, str):
axes = list(axes)

if len(axes) != ndim:
raise ValueError("axes length must match number of dimensions")
_validate_axes(axes)
return axes
if ndim is not None and len(axes) != ndim:
raise ValueError(
f"axes length ({len(axes)}) must match number of dimensions ({ndim})"
)

# axes may be list of 'x', 'y' or list of {'name': 'x'}
axes_dicts = _axes_to_dicts(axes)
axes_names = _axes_to_names(axes_dicts)

# check names (only enforced for version 0.3)
if fmt.version == "0.3":
_validate_axes_03(axes_names)
return axes_names

_validate_axes_types(axes_dicts)

return axes_dicts


def _validate_axes(axes: List[str], fmt: Format = CurrentFormat()) -> None:
def _validate_axes_03(axes: List[str]) -> None:

val_axes = tuple(axes)
if len(val_axes) == 2:
Expand Down Expand Up @@ -75,7 +146,7 @@ def write_multiscale(
group: zarr.Group,
chunks: Union[Tuple[Any, ...], int] = None,
fmt: Format = CurrentFormat(),
axes: Union[str, List[str]] = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
) -> None:
"""
Write a pyramid with multiscale metadata to disk.
Expand All @@ -93,13 +164,13 @@ def write_multiscale(
fmt: Format
The format of the ome_zarr data which should be used.
Defaults to the most current.
axes: str or list of str
the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2
or for v0.3 if 2D or 5D. Otherwise this must be provided
axes: str or list of str or list of dict
List of axes dicts, or names. Not needed for v0.1 or v0.2
or if 2D. Otherwise this must be provided
"""

dims = len(pyramid[0].shape)
axes = _validate_axes_names(dims, axes, fmt)
axes = _validate_axes(dims, axes, fmt)

paths = []
for path, dataset in enumerate(pyramid):
Expand All @@ -113,7 +184,7 @@ def write_multiscales_metadata(
group: zarr.Group,
paths: List[str],
fmt: Format = CurrentFormat(),
axes: List[str] = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
will-moore marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Write the multiscales metadata in the group.
Expand Down Expand Up @@ -142,8 +213,9 @@ def write_multiscales_metadata(
if fmt.version in ("0.1", "0.2"):
LOGGER.info("axes ignored for version 0.1 or 0.2")
else:
_validate_axes(axes, fmt)
multiscales[0]["axes"] = axes
axes = _validate_axes(axes=axes, fmt=fmt)
if axes is not None:
multiscales[0]["axes"] = axes
group.attrs["multiscales"] = multiscales


Expand All @@ -154,7 +226,7 @@ def write_image(
byte_order: Union[str, List[str]] = "tczyx",
scaler: Scaler = Scaler(),
fmt: Format = CurrentFormat(),
axes: Union[str, List[str]] = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
**metadata: JSONDict,
) -> None:
"""Writes an image to the zarr store according to ome-zarr specification
Expand All @@ -179,9 +251,9 @@ def write_image(
fmt: Format
The format of the ome_zarr data which should be used.
Defaults to the most current.
axes: str or list of str
the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2
or for v0.3 if 2D or 5D. Otherwise this must be provided
axes: str or list of str or list of dict
List of axes dicts, or names. Not needed for v0.1 or v0.2
or if 2D. Otherwise this must be provided
"""

if image.ndim > 5:
Expand All @@ -195,7 +267,7 @@ def write_image(
axes = None

# check axes before trying to scale
_validate_axes_names(image.ndim, axes, fmt)
_validate_axes(image.ndim, axes, fmt)

if chunks is not None:
chunks = _retuple(chunks, image.shape)
Expand Down
Loading