From f5fca18ab014e03cbe2f46d9dafdb5d3b71ecefd Mon Sep 17 00:00:00 2001 From: Joseph Ware Date: Tue, 29 Aug 2023 15:04:13 +0100 Subject: [PATCH] Add dimensionality information to Specs - Closes #88 - Add axes, shape and snaked information in (usually) non-points-calculating method. - Deprecate Spec.axes: did not respect Frame information - Deprecate Spec.shape: Called calculate() and discarded information --- src/scanspec/core.py | 25 ++++++ src/scanspec/specs.py | 153 +++++++++++++++++++++++-------- tests/test_specs.py | 203 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 308 insertions(+), 73 deletions(-) diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 3b1e015b..1d5877bf 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -11,6 +11,7 @@ List, Optional, Sequence, + Tuple, Type, TypeVar, Union, @@ -516,6 +517,30 @@ def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[ return squashed +class DimensionInfo(Generic[Axis]): + def __init__( + self, + axes: Tuple[Tuple[Axis, ...]], + shape: Tuple[int, ...], + snaked: Tuple[bool, ...] = None, + ): + self._axes = axes + self._shape = shape + self._snaked = snaked or (False,) * len(shape) + + @property + def axes(self) -> Tuple[Tuple[Axis, ...]]: + return self._axes + + @property + def shape(self) -> Tuple[int, ...]: + return self._shape + + @property + def snaked(self) -> Tuple[bool, ...]: + return self._snaked + + class Path(Generic[Axis]): """A consumable route through a stack of Frames, representing a scan path. diff --git a/src/scanspec/specs.py b/src/scanspec/specs.py index 80c6a011..948e4879 100644 --- a/src/scanspec/specs.py +++ b/src/scanspec/specs.py @@ -1,7 +1,9 @@ from __future__ import annotations from dataclasses import asdict +from functools import reduce from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Tuple, Type +from warnings import warn import numpy as np from pydantic import Field, parse_obj_as @@ -9,6 +11,7 @@ from .core import ( Axis, + DimensionInfo, Frames, Midpoints, Path, @@ -60,7 +63,14 @@ def axes(self) -> List[Axis]: Ordered from slowest moving to fastest moving. """ - raise NotImplementedError(self) + warn( + "axes() is deprecated, call dimension_info()", + DeprecationWarning, + stacklevel=2, + ) + return reduce( + lambda a, b: a + list(b), self.dimension_info().axes, initial=list() + ) def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: """Produce a stack of nested `Frames` that form the scan. @@ -69,6 +79,16 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: """ raise NotImplementedError(self) + def dimension_info(self) -> DimensionInfo: + """Returns the list of axes in each dimension of the scan, + paired with the information on how large each dimension of the scan is, + and whether the dimension is snaked in the dimension outside it. + + Deprecates shape() as does not need to call calculate() + Deprecates axes() as has the per-dimension information + """ + raise NotImplementedError(self) + def frames(self) -> Frames[Axis]: """Expand all the scan `Frames` and return them.""" return Path(self.calculate()).consume() @@ -79,7 +99,12 @@ def midpoints(self) -> Midpoints[Axis]: def shape(self) -> Tuple[int, ...]: """Return the final, simplified shape of the scan.""" - return tuple(len(dim) for dim in self.calculate()) + warn( + "shape() is deprecated, call dimension_info()", + DeprecationWarning, + stacklevel=2, + ) + return self.dimension_info().shape def __rmul__(self, other) -> Product[Axis]: return if_instance_do(other, int, lambda o: Product(Repeat(o), self)) @@ -127,14 +152,20 @@ class Product(Spec[Axis]): outer: Spec[Axis] = Field(description="Will be executed once") inner: Spec[Axis] = Field(description="Will be executed len(outer) times") - def axes(self) -> List: - return self.outer.axes() + self.inner.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: frames_outer = self.outer.calculate(bounds=False, nested=nested) frames_inner = self.inner.calculate(bounds, nested=True) return frames_outer + frames_inner + def dimension_info(self) -> DimensionInfo: + outer_info = self.outer.dimension_info() + inner_info = self.inner.dimension_info() + return DimensionInfo( + axes=outer_info.axes + inner_info.axes, + shape=outer_info.shape + inner_info.shape, + snaked=outer_info.snaked + inner_info.snaked, + ) + @dataclass(config=StrictConfig) class Repeat(Spec[Axis]): @@ -166,12 +197,12 @@ class Repeat(Spec[Axis]): default=True, ) - def axes(self) -> List: - return [] - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: return [Frames({}, gap=np.full(self.num, self.gap))] + def dimension_info(self) -> DimensionInfo: + return DimensionInfo(axes=((DURATION,),), shape=(self.num,)) + @dataclass(config=StrictConfig) class Zip(Spec[Axis]): @@ -203,9 +234,6 @@ class Zip(Spec[Axis]): description="The right-hand Spec to Zip, will appear later in axes" ) - def axes(self) -> List: - return self.left.axes() + self.right.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: frames_left = self.left.calculate(bounds, nested) frames_right = self.right.calculate(bounds, nested) @@ -243,6 +271,24 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: frames.append(combined) return frames + def dimension_info(self) -> DimensionInfo: + left_info = self.left.dimension_info() + right_info = self.right.dimension_info() + left = left_info.axes + padded_right = ((None,),) * ( + len(left_info.axes) - len(right_info.axes) + ) + right_info.axes + axes = tuple( + left[i] if padded_right[i] == (None,) else left[i] + padded_right[i] + for i in range(len(left_info.axes)) + ) + + return DimensionInfo( + axes=axes, + shape=left_info.shape, + snaked=left_info.snaked, # Non-matching Snake axes cannot be Zipped + ) + @dataclass(config=StrictConfig) class Mask(Spec[Axis]): @@ -271,9 +317,6 @@ class Mask(Spec[Axis]): default=True, ) - def axes(self) -> List: - return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: frames = self.spec.calculate(bounds, nested) for axis_set in self.region.axis_sets(): @@ -295,6 +338,21 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: masked_frames.append(f.extract(indices)) return masked_frames + def dimension_info(self) -> DimensionInfo: + """ + As Mask applies a Region to a Spec, which may alter the Spec, but this is not + knowable without calculating the entire Spec, we have to calculate the Spec + here. + Currently we throw away the results of this calculation, but in future we may + want to cache the result, or else modify the behaviour of this method generally + to match. + """ + frames = self.calculate(bounds=False, nested=False) + shape = tuple(len(x.midpoints) for x in frames) + axes = tuple(tuple(x.axes()) for x in frames) + snaked = tuple(isinstance(x, SnakedFrames) for x in frames) + return DimensionInfo(axes=axes, shape=shape, snaked=snaked) + # *+ bind more tightly than &|^ so without these overrides we # would need to add brackets to all combinations of Regions def __or__(self, other: Region[Axis]) -> Mask[Axis]: @@ -329,15 +387,21 @@ class Snake(Spec[Axis]): description="The Spec to run in reverse every other iteration" ) - def axes(self) -> List: - return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: return [ SnakedFrames.from_frames(segment) for segment in self.spec.calculate(bounds, nested) ] + def dimension_info(self) -> DimensionInfo: + spec_info = self.spec.dimension_info() + return DimensionInfo( + axes=spec_info.axes, + shape=spec_info.shape, + snaked=(True,) * len(spec_info.shape), + ) + return self.spec.dimension_info() + @dataclass(config=StrictConfig) class Concat(Spec[Axis]): @@ -368,13 +432,6 @@ class Concat(Spec[Axis]): default=True, ) - def axes(self) -> List: - left_axes, right_axes = self.left.axes(), self.right.axes() - # Assuming the axes are the same, the order does not matter, we inherit the - # order from the left-hand side. See also scanspec.core.concat. - assert set(left_axes) == set(right_axes), f"axes {left_axes} != {right_axes}" - return left_axes - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: dim_left = squash_frames( self.left.calculate(bounds, nested), nested and self.check_path_changes @@ -385,6 +442,19 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: dim = dim_left.concat(dim_right, self.gap) return [dim] + def dimension_info(self) -> DimensionInfo: + left_info = self.left.dimension_info() + right_info = self.right.dimension_info() + assert left_info.axes == right_info.axes + # We will squash each spec into 1 dimension + left_size = reduce(lambda a, b: a * b, left_info.shape) + right_size = reduce(lambda a, b: a * b, right_info.shape) + return DimensionInfo( + axes=left_info.axes, + shape=(left_size + right_size,), + snaked=left_info.snaked, # Non-matching Snake axes cannot be Concat + ) + @dataclass(config=StrictConfig) class Squash(Spec[Axis]): @@ -406,14 +476,18 @@ class Squash(Spec[Axis]): default=True, ) - def axes(self) -> List: - return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: dims = self.spec.calculate(bounds, nested) dim = squash_frames(dims, nested and self.check_path_changes) return [dim] + def dimension_info(self) -> DimensionInfo: + spec_info = self.spec.dimension_info() + return DimensionInfo( + axes=(reduce(lambda a, b: a + b, spec_info.axes),), + shape=(reduce(lambda a, b: a * b, spec_info.shape),), + ) + def _dimensions_from_indexes( func: Callable[[np.ndarray], Dict[Axis, np.ndarray]], @@ -458,8 +532,8 @@ class Line(Spec[Axis]): stop: float = Field(description="Midpoint of the last point of the line") num: int = Field(min=1, description="Number of frames to produce") - def axes(self) -> List: - return [self.axis] + def dimension_info(self) -> DimensionInfo: + return DimensionInfo(axes=((self.axis,),), shape=(self.num,)) def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: if self.num == 1: @@ -475,7 +549,7 @@ def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: return _dimensions_from_indexes( - self._line_from_indexes, self.axes(), self.num, bounds + self._line_from_indexes, [self.axis], self.num, bounds ) @classmethod @@ -538,17 +612,17 @@ def duration( """ return cls(DURATION, duration, num) - def axes(self) -> List: - return [self.axis] - def _repeats_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: return {self.axis: np.full(len(indexes), self.value)} def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: return _dimensions_from_indexes( - self._repeats_from_indexes, self.axes(), self.num, bounds + self._repeats_from_indexes, [self.axis], self.num, bounds ) + def dimension_info(self) -> DimensionInfo: + return DimensionInfo(axes=((self.axis,),), shape=(self.num,)) + @dataclass(config=StrictConfig) class Spiral(Spec[Axis]): @@ -577,9 +651,12 @@ class Spiral(Spec[Axis]): description="How much to rotate the angle of the spiral", default=0.0 ) - def axes(self) -> List[Axis]: - # TODO: reversed from __init__ args, a good idea? - return [self.y_axis, self.x_axis] + def dimension_info(self) -> DimensionInfo: + return DimensionInfo( + # TODO: reversed from __init__ args, a good idea? + axes=((self.y_axis, self.x_axis),), + shape=(self.num,), + ) def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: # simplest spiral equation: r = phi @@ -600,7 +677,7 @@ def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: return _dimensions_from_indexes( - self._spiral_from_indexes, self.axes(), self.num, bounds + self._spiral_from_indexes, [self.y_axis, self.x_axis], self.num, bounds ) @classmethod diff --git a/tests/test_specs.py b/tests/test_specs.py index 08b61d9b..b99d9ee4 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -60,7 +60,7 @@ def test_two_point_stepped_line() -> None: dimx, dimt = inst.calculate() assert dimx.midpoints == dimx.lower == dimx.upper == {x: pytest.approx([0, 1])} assert ( - dimt.midpoints == dimt.lower == dimt.upper == {DURATION: pytest.approx([0.1])} + dimt.midpoints == dimt.lower == dimt.upper == {DURATION: pytest.approx([0.1])} ) assert inst.frames().gap == ints("11") @@ -140,7 +140,7 @@ def test_spaced_spiral() -> None: def test_zipped_lines() -> None: inst = Line(x, 0, 1, 5).zip(Line(y, 1, 2, 5)) - assert inst.axes() == [x, y] + assert inst.dimension_info().axes == ((x, y),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([0, 0.25, 0.5, 0.75, 1]), @@ -149,11 +149,54 @@ def test_zipped_lines() -> None: assert dim.gap == ints("10000") +def test_zipped_snaked_lines() -> None: + inst = Line(x, 0, 1, 5).zip(~Line(y, 1, 2, 5)) + with pytest.raises(AssertionError) as ae: + inst.calculate() + assert ae.match("Mismatching types") + + +def test_zipped_both_snaked_lines() -> None: + inst = (~Line(x, 0, 1, 5)).zip(~Line(y, 1, 2, 5)) + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((x, y),) + assert dimension_info.snaked == (True,) + assert dimension_info.shape == (5,) + (dim,) = inst.calculate() + assert dim.midpoints == { + x: pytest.approx([0, 0.25, 0.5, 0.75, 1]), + y: pytest.approx([1, 1.25, 1.5, 1.75, 2]), + } + assert dim.gap == ints("00000") # Why is this 00000 not 10000? + + +def test_concat_snaked_lines() -> None: + inst = Line(y, 0, 1, 5).concat(~Line(y, 1, 2, 5)) + with pytest.raises(AssertionError) as ae: + inst.calculate() + assert ae.match("Mismatching types") + + +def test_concat_both_snaked_lines() -> None: + inst = (~Line(y, 0, 1, 5)).concat(~Line(y, 1, 2, 5)) + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((y,),) + assert dimension_info.snaked == (True,) + assert dimension_info.shape == (10,) + (dim,) = inst.calculate() + assert dim.midpoints == { + y: pytest.approx([0, 0.25, 0.5, 0.75, 1, 1, 1.25, 1.5, 1.75, 2]), + } + assert dim.gap == ints("0000010000") + + def test_product_lines() -> None: inst = Line(y, 1, 2, 3) * Line(x, 0, 1, 2) - assert inst.axes() == [y, x] + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((y,), (x,)) + assert dimension_info.shape == (3, 2) dims = inst.calculate() - assert len(dims) == 2 + assert len(dims) == 2 == len(dimension_info.shape) dim = Path(dims).consume() assert dim.midpoints == { x: pytest.approx([0, 1, 0, 1, 0, 1]), @@ -172,7 +215,7 @@ def test_product_lines() -> None: def test_zipped_product_lines() -> None: inst = Line(y, 1, 2, 3) * Line(x, 0, 1, 5).zip(Line(z, 2, 3, 5)) - assert inst.axes() == [y, x, z] + assert inst.dimension_info().axes == ((y,), (x, z)) dimy, dimxz = inst.calculate() assert dimxz.midpoints == { x: pytest.approx([0, 0.25, 0.5, 0.75, 1]), @@ -184,9 +227,78 @@ def test_zipped_product_lines() -> None: assert inst.frames().gap == ints("100001000010000") +def test_zipping_multiple_axes() -> None: + spiral = Spiral(x, y, 0, 10, 5, 50, 10) + spiral_midpoints = { + y: pytest.approx( + [5.4, 6.4, 19.7, 23.8, 15.4, 1.7, -8.6, -10.7, -4.1, 8.3], abs=0.1 + ), + x: pytest.approx( + [0.3, -0.9, -0.7, 0.5, 1.5, 1.6, 0.7, -0.6, -1.8, -2.4], abs=0.1 + ), + } + inst = spiral.zip(Line(z, 0, 9, 10)) + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((y, x, z),) # Spiral reverses y, x axes + assert dimension_info.shape == (10,) + (dimyxz,) = inst.calculate() + assert dimyxz.midpoints == { + **spiral_midpoints, + z: pytest.approx([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + } + assert inst.frames().gap == ints("1000000000") + + inst = spiral.zip(Line(z, 0, 0, 1)) + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((y, x, z),) + assert dimension_info.shape == (10,) + (dimyxz,) = inst.calculate() + assert dimyxz.midpoints == {**spiral_midpoints, z: pytest.approx([0] * 10)} + assert inst.frames().gap == ints("1000000000") + + +def test_zipping_higher_dimensionality() -> None: + # If dimensions(left) > dimensions(right), zip from innermost to outermost + grid = Line(y, 1, 2, 3) * Line(x, 0, 1, 2) + dimy_midpoints = { + y: pytest.approx([1, 1.5, 2]), + } + dimx_midpoints = { + x: pytest.approx([0, 1]), + } + + inst = grid.zip(Line(z, 0, 5, 2)) + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((y,), (x, z)) + assert dimension_info.shape == (3, 2) + dimy, dimxz = inst.calculate() + assert dimy.midpoints == dimy_midpoints + assert dimxz.midpoints == {**dimx_midpoints, z: pytest.approx([0, 5])} + + threed_grid = Line(z, 0, 5, 2) * grid + inst = threed_grid.zip(Line("p", 0, 5, 3) * Line("q", 0, 5, 2)) + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((z,), (y, "p"), (x, "q")) + assert dimension_info.shape == (2, 3, 2) + dimz, dimyp, dimxq = inst.calculate() + assert dimz.midpoints == {z: pytest.approx([0, 5])} + assert dimyp.midpoints == {**dimy_midpoints, "p": pytest.approx([0, 2.5, 5])} + assert dimxq.midpoints == {**dimx_midpoints, "q": pytest.approx([0, 5])} + + # If dimensions(right) == 1 and len(dimensions(right)[0]) == 1, + # dimensions(right)[0] *= len(dimensions(left)[-1] + inst = grid.zip(Line(z, 0, 0, 1)) + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((y,), (x, z)) + assert dimension_info.shape == (3, 2) + dimy, dimxz = inst.calculate() + assert dimy.midpoints == dimy_midpoints + assert dimxz.midpoints == {**dimx_midpoints, z: pytest.approx([0] * 2)} + + def test_squashed_product() -> None: inst = Squash(Line(y, 1, 2, 3) * Line(x, 0, 1, 2)) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([0, 1, 0, 1, 0, 1]), @@ -207,7 +319,9 @@ def test_squashed_multiplied_snake_scan() -> None: inst = Line(z, 1, 2, 2) * Squash( Line(y, 1, 2, 2) * ~Line.bounded(x, 3, 7, 2) * Static.duration(9, 2) ) - assert inst.axes() == [z, y, x, DURATION] + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((z,), (y, x, DURATION)) + assert dimension_info.snaked == (False, False) dimz, dimxyt = inst.calculate() for d in dimxyt.midpoints, dimxyt.lower, dimxyt.upper: assert d == { @@ -221,7 +335,9 @@ def test_squashed_multiplied_snake_scan() -> None: def test_product_snaking_lines() -> None: inst = Line(y, 1, 2, 3) * ~Line(x, 0, 1, 2) - assert inst.axes() == [y, x] + dimension_info = inst.dimension_info() + assert dimension_info.axes == ((y,), (x,)) + assert dimension_info.snaked == (False, True) dims = inst.calculate() assert len(dims) == 2 dim = Path(dims).consume() @@ -242,7 +358,7 @@ def test_product_snaking_lines() -> None: def test_concat_lines() -> None: inst = Concat(Line(x, 0, 1, 2), Line(x, 1, 2, 3)) - assert inst.axes() == [x] + assert inst.dimension_info().axes == ((x,),) (dim,) = inst.calculate() assert dim.midpoints == {x: pytest.approx([0, 1, 1, 1.5, 2])} assert dim.lower == {x: pytest.approx([-0.5, 0.5, 0.75, 1.25, 1.75])} @@ -252,7 +368,7 @@ def test_concat_lines() -> None: def test_rect_region() -> None: inst = Line(y, 1, 3, 5) * Line(x, 0, 2, 3) & Rectangle(x, y, 0, 1, 1.5, 2.2) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([0, 1, 0, 1, 0, 1]), @@ -273,7 +389,7 @@ def test_rect_region_3D() -> None: inst = Static(z, 3.2, 2) * Line(y, 1, 3, 5) * Line(x, 0, 2, 3) & Rectangle( x, y, 0, 1, 1.5, 2.2 ) - assert inst.axes() == [z, y, x] + assert inst.dimension_info().axes == ((z,), (y, x)) zdim, xydim = inst.calculate() assert zdim.midpoints == {z: pytest.approx([3.2, 3.2])} assert zdim.midpoints is zdim.upper @@ -297,7 +413,7 @@ def test_rect_region_union() -> None: inst = Line(y, 1, 3, 5) * Line(x, 0, 2, 3) & Rectangle( x, y, 0, 1, 1.5, 2.2 ) | Rectangle(x, y, 0.5, 1.5, 2, 2.5) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([0, 1, 0, 1, 2, 0, 1, 2, 1, 2]), @@ -308,11 +424,11 @@ def test_rect_region_union() -> None: def test_rect_region_intersection() -> None: inst = ( - Line(y, 1, 3, 5) * Line(x, 0, 2, 3) - & Rectangle(x, y, 0, 1, 1.5, 2.2) - & Rectangle(x, y, 0.5, 1.5, 2, 2.5) + Line(y, 1, 3, 5) * Line(x, 0, 2, 3) + & Rectangle(x, y, 0, 1, 1.5, 2.2) + & Rectangle(x, y, 0.5, 1.5, 2, 2.5) ) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([1, 1]), @@ -324,10 +440,10 @@ def test_rect_region_intersection() -> None: def test_rect_region_difference() -> None: # Bracket to force testing Mask.__sub__ rather than Region.__sub__ inst = ( - Line(y, 1, 3, 5) * Line(x, 0, 2, 3).zip(Static(DURATION, 0.1)) - & Rectangle(x, y, 0, 1, 1.5, 2.2) - ) - Rectangle(x, y, 0.5, 1.5, 2, 2.5) - assert inst.axes() == [y, x, DURATION] + Line(y, 1, 3, 5) * Line(x, 0, 2, 3).zip(Static(DURATION, 0.1)) + & Rectangle(x, y, 0, 1, 1.5, 2.2) + ) - Rectangle(x, y, 0.5, 1.5, 2, 2.5) + assert inst.dimension_info().axes == ((y, x, DURATION),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([0, 1, 0, 0]), @@ -341,7 +457,7 @@ def test_rect_region_symmetricdifference() -> None: inst = Line(y, 1, 3, 5) * Line(x, 0, 2, 3) & Rectangle( x, y, 0, 1, 1.5, 2.2 ) ^ Rectangle(x, y, 0.5, 1.5, 2, 2.5) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([0, 1, 0, 2, 0, 2, 1, 2]), @@ -352,7 +468,7 @@ def test_rect_region_symmetricdifference() -> None: def test_circle_region() -> None: inst = Line(y, 1, 3, 3) * Line(x, 0, 2, 3) & Circle(x, y, 1, 2, 1) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([1, 0, 1, 2, 1]), @@ -375,7 +491,7 @@ def test_circle_snaked_region() -> None: Circle(x, y, 1, 2, 1), check_path_changes=False, ) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([1, 2, 1, 0, 1]), @@ -394,7 +510,7 @@ def test_circle_snaked_region() -> None: def test_ellipse_region() -> None: inst = Line("y", 1, 3, 3) * Line("x", 0, 2, 3) & Ellipse(x, y, 1, 2, 2, 1, 45) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([0, 1, 0, 1, 2, 1, 2]), @@ -415,7 +531,7 @@ def test_polygon_region() -> None: x_verts = [0, 0.5, 4.0, 2.5] y_verts = [0, 3.5, 3.5, 0.5] inst = Line("y", 1, 3, 3) * Line("x", 0, 4, 5) & Polygon(x, y, x_verts, y_verts) - assert inst.axes() == [y, x] + assert inst.dimension_info().axes == ((y, x),) (dim,) = inst.calculate() assert dim.midpoints == { x: pytest.approx([1, 2, 1, 2, 3, 1, 2, 3]), @@ -436,6 +552,13 @@ def test_xyz_stack() -> None: # Beam selector scan moves bounded between midpoints and lower and upper bounds at # maximum speed. Turnaround sections are where it sends the triggers spec = Line(z, 0, 1, 2) * ~Line(y, 0, 2, 3) * ~Line(x, 0, 3, 4) + info = spec.dimension_info() + assert info.axes == ( + (z,), + (y,), + (x,), + ) + assert info.shape == (2, 3, 4) dim = spec.frames() assert len(dim) == 24 assert dim.lower == { @@ -469,6 +592,9 @@ def test_beam_selector() -> None: # Beam selector scan moves bounded between midpoints and lower and upper bounds at # maximum speed. Turnaround sections are where it sends the triggers spec = 10 * ~Line.bounded(x, 11, 19, 1) + info = spec.dimension_info() + assert info.axes == ((DURATION,), (x,)) + assert info.shape == (10, 1) dim = spec.frames() assert len(dim) == 10 assert dim.lower == {x: pytest.approx([11, 19, 11, 19, 11, 19, 11, 19, 11, 19])} @@ -531,30 +657,37 @@ def test_multiple_statics_with_grid(): @pytest.mark.parametrize( - "spec,expected_shape", + "spec,expected_shape,expected_axes", [ - (Line("x", 0.0, 1.0, 1), (1,)), - (Line("x", 0.0, 1.0, 5), (5,)), - (Spiral("x", "y", 0.0, 0.0, 1.0, 1.0, 5, 0.0), (5,)), - (Line("x", 0.0, 1.0, 2) * Line("y", 0.0, 1.0, 2), (2, 2)), - (Squash(Line("x", 0.0, 1.0, 2) * Line("y", 0.0, 1.0, 2)), (4,)), - (Zip(Line("x", 0.0, 1.0, 2), Line("y", 0.0, 1.0, 2)), (2,)), - (Concat(Line("x", 0.0, 1.0, 2), Line("x", 0.0, 1.0, 2)), (4,)), + (Line("x", 0.0, 1.0, 1), (1,), (("x",),)), + (Line("x", 0.0, 1.0, 5), (5,), (("x",),)), + (Spiral("x", "y", 0.0, 0.0, 1.0, 1.0, 5, 0.0), (5,), (("y", "x"),)), + (Line("x", 0.0, 1.0, 2) * Line("y", 0.0, 1.0, 2), (2, 2), (("x",), ("y",))), + (Squash(Line("x", 0.0, 1.0, 2) * Line("y", 0.0, 1.0, 2)), (4,), (("x", "y"),)), + (Zip(Line("x", 0.0, 1.0, 2), Line("y", 0.0, 1.0, 2)), (2,), (("x", "y"),)), + (Concat(Line("x", 0.0, 1.0, 2), Line("x", 0.0, 1.0, 2)), (4,), (("x",),)), ( Line("x", 0.0, 1.0, 2) * Line("y", 0.0, 1.0, 2) * Line("z", 0.0, 2.0, 2), (2, 2, 2), + (("x",), ("y",), ("z",)), ), ( Zip(Line("x", 0.0, 1.0, 2), Line("y", 0.0, 1.0, 2)) * Line("z", 0.0, 2.0, 2), (2, 2), + (("x", "y"), ("z",)), ), ( Concat(Line("x", 0.0, 1.0, 2), Line("x", 0.0, 1.0, 2)) * Line("z", 0.0, 2.0, 2), (4, 2), + (("x",), ("z",)), ), ], ) -def test_shape(spec: Spec, expected_shape: Tuple[int, ...]): - assert expected_shape == spec.shape() +def test_dimension_info( + spec: Spec, expected_shape: Tuple[int, ...], expected_axes: Tuple[Tuple[str, ...]] +): + dimension_info = spec.dimension_info() + assert expected_shape == dimension_info.shape + assert expected_axes == dimension_info.axes