Skip to content

Commit

Permalink
Add dimensionality information to Specs
Browse files Browse the repository at this point in the history
- Closes #88
- Add axes, shape and snaked information in (usually) non-points-calculating method.
- Deprecate Spec.axes: did not respect Frame information
- Deprecate Spec.shape: Called calculate() and discarded information
  • Loading branch information
DiamondJoseph committed Aug 30, 2023
1 parent de99a51 commit f5fca18
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 73 deletions.
25 changes: 25 additions & 0 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -516,6 +517,30 @@ def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[
return squashed


class DimensionInfo(Generic[Axis]):
def __init__(
self,
axes: Tuple[Tuple[Axis, ...]],
shape: Tuple[int, ...],
snaked: Tuple[bool, ...] = None,
):
self._axes = axes
self._shape = shape
self._snaked = snaked or (False,) * len(shape)

@property
def axes(self) -> Tuple[Tuple[Axis, ...]]:
return self._axes

@property
def shape(self) -> Tuple[int, ...]:
return self._shape

@property
def snaked(self) -> Tuple[bool, ...]:
return self._snaked


class Path(Generic[Axis]):
"""A consumable route through a stack of Frames, representing a scan path.
Expand Down
153 changes: 115 additions & 38 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

from dataclasses import asdict
from functools import reduce
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Tuple, Type
from warnings import warn

import numpy as np
from pydantic import Field, parse_obj_as
from pydantic.dataclasses import dataclass

from .core import (
Axis,
DimensionInfo,
Frames,
Midpoints,
Path,
Expand Down Expand Up @@ -60,7 +63,14 @@ def axes(self) -> List[Axis]:
Ordered from slowest moving to fastest moving.
"""
raise NotImplementedError(self)
warn(
"axes() is deprecated, call dimension_info()",
DeprecationWarning,
stacklevel=2,
)
return reduce(
lambda a, b: a + list(b), self.dimension_info().axes, initial=list()
)

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
"""Produce a stack of nested `Frames` that form the scan.
Expand All @@ -69,6 +79,16 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
"""
raise NotImplementedError(self)

def dimension_info(self) -> DimensionInfo:
"""Returns the list of axes in each dimension of the scan,
paired with the information on how large each dimension of the scan is,
and whether the dimension is snaked in the dimension outside it.
Deprecates shape() as does not need to call calculate()
Deprecates axes() as has the per-dimension information
"""
raise NotImplementedError(self)

def frames(self) -> Frames[Axis]:
"""Expand all the scan `Frames` and return them."""
return Path(self.calculate()).consume()
Expand All @@ -79,7 +99,12 @@ def midpoints(self) -> Midpoints[Axis]:

def shape(self) -> Tuple[int, ...]:
"""Return the final, simplified shape of the scan."""
return tuple(len(dim) for dim in self.calculate())
warn(
"shape() is deprecated, call dimension_info()",
DeprecationWarning,
stacklevel=2,
)
return self.dimension_info().shape

def __rmul__(self, other) -> Product[Axis]:
return if_instance_do(other, int, lambda o: Product(Repeat(o), self))
Expand Down Expand Up @@ -127,14 +152,20 @@ class Product(Spec[Axis]):
outer: Spec[Axis] = Field(description="Will be executed once")
inner: Spec[Axis] = Field(description="Will be executed len(outer) times")

def axes(self) -> List:
return self.outer.axes() + self.inner.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames_outer = self.outer.calculate(bounds=False, nested=nested)
frames_inner = self.inner.calculate(bounds, nested=True)
return frames_outer + frames_inner

def dimension_info(self) -> DimensionInfo:
outer_info = self.outer.dimension_info()
inner_info = self.inner.dimension_info()
return DimensionInfo(
axes=outer_info.axes + inner_info.axes,
shape=outer_info.shape + inner_info.shape,
snaked=outer_info.snaked + inner_info.snaked,
)


@dataclass(config=StrictConfig)
class Repeat(Spec[Axis]):
Expand Down Expand Up @@ -166,12 +197,12 @@ class Repeat(Spec[Axis]):
default=True,
)

def axes(self) -> List:
return []

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return [Frames({}, gap=np.full(self.num, self.gap))]

def dimension_info(self) -> DimensionInfo:
return DimensionInfo(axes=((DURATION,),), shape=(self.num,))


@dataclass(config=StrictConfig)
class Zip(Spec[Axis]):
Expand Down Expand Up @@ -203,9 +234,6 @@ class Zip(Spec[Axis]):
description="The right-hand Spec to Zip, will appear later in axes"
)

def axes(self) -> List:
return self.left.axes() + self.right.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames_left = self.left.calculate(bounds, nested)
frames_right = self.right.calculate(bounds, nested)
Expand Down Expand Up @@ -243,6 +271,24 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames.append(combined)
return frames

def dimension_info(self) -> DimensionInfo:
left_info = self.left.dimension_info()
right_info = self.right.dimension_info()
left = left_info.axes
padded_right = ((None,),) * (
len(left_info.axes) - len(right_info.axes)
) + right_info.axes
axes = tuple(
left[i] if padded_right[i] == (None,) else left[i] + padded_right[i]
for i in range(len(left_info.axes))
)

return DimensionInfo(
axes=axes,
shape=left_info.shape,
snaked=left_info.snaked, # Non-matching Snake axes cannot be Zipped
)


@dataclass(config=StrictConfig)
class Mask(Spec[Axis]):
Expand Down Expand Up @@ -271,9 +317,6 @@ class Mask(Spec[Axis]):
default=True,
)

def axes(self) -> List:
return self.spec.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames = self.spec.calculate(bounds, nested)
for axis_set in self.region.axis_sets():
Expand All @@ -295,6 +338,21 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
masked_frames.append(f.extract(indices))
return masked_frames

def dimension_info(self) -> DimensionInfo:
"""
As Mask applies a Region to a Spec, which may alter the Spec, but this is not
knowable without calculating the entire Spec, we have to calculate the Spec
here.
Currently we throw away the results of this calculation, but in future we may
want to cache the result, or else modify the behaviour of this method generally
to match.
"""
frames = self.calculate(bounds=False, nested=False)
shape = tuple(len(x.midpoints) for x in frames)
axes = tuple(tuple(x.axes()) for x in frames)
snaked = tuple(isinstance(x, SnakedFrames) for x in frames)
return DimensionInfo(axes=axes, shape=shape, snaked=snaked)

# *+ bind more tightly than &|^ so without these overrides we
# would need to add brackets to all combinations of Regions
def __or__(self, other: Region[Axis]) -> Mask[Axis]:
Expand Down Expand Up @@ -329,15 +387,21 @@ class Snake(Spec[Axis]):
description="The Spec to run in reverse every other iteration"
)

def axes(self) -> List:
return self.spec.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return [
SnakedFrames.from_frames(segment)
for segment in self.spec.calculate(bounds, nested)
]

def dimension_info(self) -> DimensionInfo:
spec_info = self.spec.dimension_info()
return DimensionInfo(
axes=spec_info.axes,
shape=spec_info.shape,
snaked=(True,) * len(spec_info.shape),
)
return self.spec.dimension_info()


@dataclass(config=StrictConfig)
class Concat(Spec[Axis]):
Expand Down Expand Up @@ -368,13 +432,6 @@ class Concat(Spec[Axis]):
default=True,
)

def axes(self) -> List:
left_axes, right_axes = self.left.axes(), self.right.axes()
# Assuming the axes are the same, the order does not matter, we inherit the
# order from the left-hand side. See also scanspec.core.concat.
assert set(left_axes) == set(right_axes), f"axes {left_axes} != {right_axes}"
return left_axes

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
dim_left = squash_frames(
self.left.calculate(bounds, nested), nested and self.check_path_changes
Expand All @@ -385,6 +442,19 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
dim = dim_left.concat(dim_right, self.gap)
return [dim]

def dimension_info(self) -> DimensionInfo:
left_info = self.left.dimension_info()
right_info = self.right.dimension_info()
assert left_info.axes == right_info.axes
# We will squash each spec into 1 dimension
left_size = reduce(lambda a, b: a * b, left_info.shape)
right_size = reduce(lambda a, b: a * b, right_info.shape)
return DimensionInfo(
axes=left_info.axes,
shape=(left_size + right_size,),
snaked=left_info.snaked, # Non-matching Snake axes cannot be Concat
)


@dataclass(config=StrictConfig)
class Squash(Spec[Axis]):
Expand All @@ -406,14 +476,18 @@ class Squash(Spec[Axis]):
default=True,
)

def axes(self) -> List:
return self.spec.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
dims = self.spec.calculate(bounds, nested)
dim = squash_frames(dims, nested and self.check_path_changes)
return [dim]

def dimension_info(self) -> DimensionInfo:
spec_info = self.spec.dimension_info()
return DimensionInfo(
axes=(reduce(lambda a, b: a + b, spec_info.axes),),
shape=(reduce(lambda a, b: a * b, spec_info.shape),),
)


def _dimensions_from_indexes(
func: Callable[[np.ndarray], Dict[Axis, np.ndarray]],
Expand Down Expand Up @@ -458,8 +532,8 @@ class Line(Spec[Axis]):
stop: float = Field(description="Midpoint of the last point of the line")
num: int = Field(min=1, description="Number of frames to produce")

def axes(self) -> List:
return [self.axis]
def dimension_info(self) -> DimensionInfo:
return DimensionInfo(axes=((self.axis,),), shape=(self.num,))

def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
if self.num == 1:
Expand All @@ -475,7 +549,7 @@ def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return _dimensions_from_indexes(
self._line_from_indexes, self.axes(), self.num, bounds
self._line_from_indexes, [self.axis], self.num, bounds
)

@classmethod
Expand Down Expand Up @@ -538,17 +612,17 @@ def duration(
"""
return cls(DURATION, duration, num)

def axes(self) -> List:
return [self.axis]

def _repeats_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
return {self.axis: np.full(len(indexes), self.value)}

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return _dimensions_from_indexes(
self._repeats_from_indexes, self.axes(), self.num, bounds
self._repeats_from_indexes, [self.axis], self.num, bounds
)

def dimension_info(self) -> DimensionInfo:
return DimensionInfo(axes=((self.axis,),), shape=(self.num,))


@dataclass(config=StrictConfig)
class Spiral(Spec[Axis]):
Expand Down Expand Up @@ -577,9 +651,12 @@ class Spiral(Spec[Axis]):
description="How much to rotate the angle of the spiral", default=0.0
)

def axes(self) -> List[Axis]:
# TODO: reversed from __init__ args, a good idea?
return [self.y_axis, self.x_axis]
def dimension_info(self) -> DimensionInfo:
return DimensionInfo(
# TODO: reversed from __init__ args, a good idea?
axes=((self.y_axis, self.x_axis),),
shape=(self.num,),
)

def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
# simplest spiral equation: r = phi
Expand All @@ -600,7 +677,7 @@ def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return _dimensions_from_indexes(
self._spiral_from_indexes, self.axes(), self.num, bounds
self._spiral_from_indexes, [self.y_axis, self.x_axis], self.num, bounds
)

@classmethod
Expand Down
Loading

0 comments on commit f5fca18

Please sign in to comment.