Skip to content

Commit

Permalink
Refactor Frames to require _native_world_axis_object_components
Browse files Browse the repository at this point in the history
This means that world_axis_object_components can be automatically sorted
for all frames.
  • Loading branch information
Cadair committed Sep 26, 2024
1 parent ccbc5e9 commit 18bb8ef
Showing 1 changed file with 99 additions and 82 deletions.
181 changes: 99 additions & 82 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,81 @@ def get_ctype_from_ucd(ucd):
return UCD1_TO_CTYPE.get(ucd, "")


@dataclass
class FrameProperties:
naxes: InitVar[int]
axes_type: tuple[str]
unit: tuple[u.Unit] = None
axes_names: tuple[str] = None
axis_physical_types: list[str] = None

def __post_init__(self, naxes):
if isinstance(self.axes_type, str):
self.axes_type = (self.axes_type,)

Check warning on line 190 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L190

Added line #L190 was not covered by tests
else:
self.axes_type = tuple(self.axes_type)

if len(self.axes_type) != naxes:
raise ValueError("Length of axes_type does not match number of axes.")

if self.unit is not None:
if astutil.isiterable(self.unit):
unit = tuple(self.unit)
else:
unit = (self.unit,)
if len(unit) != naxes:
raise ValueError("Number of units does not match number of axes.")
else:
self.unit = tuple(u.Unit(au) for au in unit)
else:
self.unit = tuple(u.dimensionless_unscaled for na in range(naxes))

if self.axes_names is not None:
if isinstance(self.axes_names, str):
self.axes_names = (self.axes_names,)
else:
self.axes_names = tuple(self.axes_names)
if len(self.axes_names) != naxes:
raise ValueError("Number of axes names does not match number of axes.")
else:
self.axes_names = tuple([""] * naxes)

if self.axis_physical_types is not None:
if isinstance(self.axis_physical_types, str):
self.axis_physical_types = (self.axis_physical_types,)
elif not isiterable(self.axis_physical_types):
raise TypeError("axis_physical_types must be of type string or iterable of strings")

Check warning on line 223 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L223

Added line #L223 was not covered by tests
if len(self.axis_physical_types) != naxes:
raise ValueError(f'"axis_physical_types" must be of length {naxes}')
ph_type = []
for axt in self.axis_physical_types:
if axt not in VALID_UCDS and not axt.startswith("custom:"):
ph_type.append(f"custom:{axt}")
else:
ph_type.append(axt)

validate_physical_types(ph_type)
self.axis_physical_types = tuple(ph_type)

@property
def _default_axis_physical_type(self):
"""
The default physical types to use for this frame if none are specified
by the user.
"""
return tuple("custom:{}".format(t) for t in self.axes_type)

Check warning on line 242 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L242

Added line #L242 was not covered by tests


class BaseCoordinateFrame(abc.ABC):
"""
API Definition for a Coordinate frame
"""

_prop: FrameProperties
"""
The FrameProperties object holding properties in native frame order.
"""

@property
@abc.abstractmethod
def naxes(self) -> int:
Expand Down Expand Up @@ -253,7 +324,6 @@ def world_axis_object_classes(self):
"""

@property
@abc.abstractmethod
def world_axis_object_components(self):
"""
The APE 14 object components for this frame.
Expand All @@ -262,71 +332,30 @@ def world_axis_object_components(self):
--------
astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components
"""
if self.naxes == 1:
return self._native_world_axis_object_components


@dataclass
class FrameProperties:
naxes: InitVar[int]
axes_type: tuple[str]
unit: tuple[u.Unit] = None
axes_names: tuple[str] = None
axis_physical_types: list[str] = None

def __post_init__(self, naxes):
if isinstance(self.axes_type, str):
self.axes_type = (self.axes_type,)
else:
self.axes_type = tuple(self.axes_type)

if len(self.axes_type) != naxes:
raise ValueError("Length of axes_type does not match number of axes.")

if self.unit is not None:
if astutil.isiterable(self.unit):
unit = tuple(self.unit)
else:
unit = (self.unit,)
if len(unit) != naxes:
raise ValueError("Number of units does not match number of axes.")
else:
self.unit = tuple(u.Unit(au) for au in unit)
else:
self.unit = tuple(u.dimensionless_unscaled for na in range(naxes))

if self.axes_names is not None:
if isinstance(self.axes_names, str):
self.axes_names = (self.axes_names,)
else:
self.axes_names = tuple(self.axes_names)
if len(self.axes_names) != naxes:
raise ValueError("Number of axes names does not match number of axes.")
else:
self.axes_names = tuple([""] * naxes)

if self.axis_physical_types is not None:
if isinstance(self.axis_physical_types, str):
self.axis_physical_types = (self.axis_physical_types,)
elif not isiterable(self.axis_physical_types):
raise TypeError("axis_physical_types must be of type string or iterable of strings")
if len(self.axis_physical_types) != naxes:
raise ValueError(f'"axis_physical_types" must be of length {naxes}')
ph_type = []
for axt in self.axis_physical_types:
if axt not in VALID_UCDS and not axt.startswith("custom:"):
ph_type.append(f"custom:{axt}")
else:
ph_type.append(axt)

validate_physical_types(ph_type)
self.axis_physical_types = tuple(ph_type)
# If we have more than one axis then we should sort the native
# components by the axes_order.
ordered = np.array(self._native_world_axis_object_components,
dtype=object)[np.argsort(self.axes_order)]
return list(map(tuple, ordered))

@property
def _default_axis_physical_type(self):
@abc.abstractmethod
def _native_world_axis_object_components(self):
"""
The default physical types to use for this frame if none are specified
by the user.
This property holds the "native" frame order of the components.
The native order of the componets is the order the frame assumes the
axes are in when creating the high level objects, for example
``CelestialFrame`` creates ``SkyCoord`` objects which are in lon, lat
order (in their positional args).
This property is used both to construct the ordered
``world_axis_object_components`` property as well as by `CompositeFrame`
to be able to get the components in their native order.
"""
return tuple("custom:{}".format(t) for t in self.axes_type)


class CoordinateFrame(BaseCoordinateFrame):
Expand Down Expand Up @@ -401,8 +430,6 @@ def __str__(self):
return self.__class__.__name__

def _sort_property(self, property):
#return tuple(dict(sorted(zip(property, self.axes_order),
# key=lambda x: x[1])).keys())
sorted_prop = sorted(zip(property, self.axes_order),
key=lambda x: x[1])
return tuple([t[0] for t in sorted_prop])
Expand All @@ -426,7 +453,6 @@ def naxes(self):
def unit(self):
"""The unit of this frame."""
return self._sort_property(self._prop.unit)
#return self._prop.unit

@property
def axes_names(self):
Expand Down Expand Up @@ -464,19 +490,18 @@ def world_axis_object_classes(self):
{'unit': unit})
for i, (at, unit) in enumerate(zip(self.axes_type, self.unit))}

@property
def world_axis_object_components(self):
return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)]

@property
def _native_world_axis_object_components(self):
"""Defines the target component ordering (i.e. not taking into account axes_order)"""
return self.world_axis_object_components
return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)]


class CelestialFrame(CoordinateFrame):
"""
Celestial Frame Representation
Representation of a Celesital coordinate system.
This class has a native order of longitude then latitude, meaning
``axes_names``, ``unit`` should be lon, lat ordered. If your transform is
in a different order this should be specified with ``axes_order``.
Parameters
----------
Expand Down Expand Up @@ -551,14 +576,6 @@ def _native_world_axis_object_components(self):
return [('celestial', 0, 'spherical.lon'),
('celestial', 1, 'spherical.lat')]

@property
def world_axis_object_components(self):
# Sort the native waoc by the axes order. The axes order may have jumps
# in it if there are other frames in between the components.
ordered = np.array(self._native_world_axis_object_components,
dtype=object)[np.argsort(self.axes_order)]
return list(map(tuple, ordered))


class SpectralFrame(CoordinateFrame):
"""
Expand Down Expand Up @@ -616,7 +633,7 @@ def world_axis_object_classes(self):
{'unit': self.unit[0]})}

@property
def world_axis_object_components(self):
def _native_world_axis_object_components(self):
return [('spectral', 0, 'value')]


Expand Down Expand Up @@ -685,7 +702,7 @@ def world_axis_object_classes(self):
return {'temporal': comp}

@property
def world_axis_object_components(self):
def _native_world_axis_object_components(self):
if isinstance(self.reference_frame.value, np.ndarray):
return [('temporal', 0, 'value')]

Expand Down Expand Up @@ -859,7 +876,7 @@ def world_axis_object_classes(self):
)}

@property
def world_axis_object_components(self):
def _native_world_axis_object_components(self):
return [('stokes', 0, 'value')]


Expand Down

0 comments on commit 18bb8ef

Please sign in to comment.