Skip to content

Commit

Permalink
refactor[next]: NamedRange/NamedIndex tuple to NamedTuple (#1490)
Browse files Browse the repository at this point in the history
Change NamedRange and NamedIndex from being a plain tuple to a
NamedTuple for cleaner element access.

---------

Co-authored-by: Hannes Vogt <[email protected]>
Co-authored-by: Enrique González Paredes <[email protected]>
  • Loading branch information
3 people authored Mar 21, 2024
1 parent 879c836 commit e344afd
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 213 deletions.
122 changes: 60 additions & 62 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ClassVar,
Final,
Generic,
NamedTuple,
Never,
Optional,
ParamSpec,
Expand Down Expand Up @@ -84,7 +85,7 @@ def __str__(self) -> str:
return f"{self.value}[{self.kind}]"

def __call__(self, val: int) -> NamedIndex:
return self, val
return NamedIndex(self, val)


class Infinity(enum.Enum):
Expand Down Expand Up @@ -248,9 +249,16 @@ def __str__(self) -> str:

FiniteUnitRange: TypeAlias = UnitRange[int, int]

_Rng = TypeVar(
"_Rng",
FiniteUnitRange,
UnitRange[Infinity, int],
UnitRange[int, Infinity],
UnitRange[Infinity, Infinity],
)

RangeLike: TypeAlias = (
UnitRange
_Rng
| range
| tuple[core_defs.IntegralScalar, core_defs.IntegralScalar]
| core_defs.IntegralScalar
Expand Down Expand Up @@ -282,10 +290,26 @@ def unit_range(r: RangeLike) -> UnitRange:
raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.")


class NamedRange(NamedTuple, Generic[_Rng]):
dim: Dimension
unit_range: _Rng

def __str__(self) -> str:
return f"{self.dim}={self.unit_range}"


IntIndex: TypeAlias = int | core_defs.IntegralScalar
NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple


class NamedIndex(NamedTuple):
dim: Dimension
value: IntIndex

def __str__(self) -> str:
return f"{self.dim}={self.value}"


FiniteNamedRange: TypeAlias = NamedRange[FiniteUnitRange]
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice
Expand All @@ -304,41 +328,22 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]:
return isinstance(p, (int, core_defs.INTEGRAL_TYPES))


def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
return (
isinstance(v, tuple)
and len(v) == 2
and isinstance(v[0], Dimension)
and isinstance(v[1], UnitRange)
)


def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]:
return UnitRange.is_finite(v[1])
return UnitRange.is_finite(v.unit_range)


def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
return (
isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1])
def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[slice]:
return isinstance(obj, slice) and (
isinstance(obj.start, NamedIndex) and isinstance(obj.stop, NamedIndex)
)


def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]:
return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop))


def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]:
return (
is_int_index(v)
or is_named_range(v)
or is_named_index(v)
or isinstance(v, slice)
or v is Ellipsis
)
return is_int_index(v) or isinstance(v, (NamedRange, NamedIndex, slice)) or v is Ellipsis


def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]:
return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v)
return isinstance(v, Sequence) and all(isinstance(e, (NamedRange, NamedIndex)) for e in v)


def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]:
Expand All @@ -356,28 +361,21 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence:


def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange:
return (v[0], unit_range(v[1]))


_Rng = TypeVar(
"_Rng",
UnitRange[int, int],
UnitRange[Infinity, int],
UnitRange[int, Infinity],
UnitRange[Infinity, Infinity],
)
if isinstance(v, NamedRange):
return v
return NamedRange(v[0], unit_range(v[1]))


@dataclasses.dataclass(frozen=True, init=False)
class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]):
class Domain(Sequence[NamedRange[_Rng]], Generic[_Rng]):
"""Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s."""

dims: tuple[Dimension, ...]
ranges: tuple[_Rng, ...]

def __init__(
self,
*args: tuple[Dimension, _Rng],
*args: NamedRange[_Rng],
dims: Optional[Sequence[Dimension]] = None,
ranges: Optional[Sequence[_Rng]] = None,
) -> None:
Expand Down Expand Up @@ -406,7 +404,7 @@ def __init__(
object.__setattr__(self, "dims", tuple(dims))
object.__setattr__(self, "ranges", tuple(ranges))
else:
if not all(is_named_range(arg) for arg in args):
if not all(isinstance(arg, NamedRange) for arg in args):
raise ValueError(
f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'."
)
Expand Down Expand Up @@ -437,25 +435,25 @@ def is_empty(self) -> bool:
return any(rng.is_empty() for rng in self.ranges)

@overload
def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ...
def __getitem__(self, index: int) -> NamedRange: ...

@overload
def __getitem__(self, index: slice) -> Self: ...

@overload
def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ...
def __getitem__(self, index: Dimension) -> NamedRange: ...

def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain:
if isinstance(index, int):
return self.dims[index], self.ranges[index]
return NamedRange(dim=self.dims[index], unit_range=self.ranges[index])
elif isinstance(index, slice):
dims_slice = self.dims[index]
ranges_slice = self.ranges[index]
return Domain(dims=dims_slice, ranges=ranges_slice)
elif isinstance(index, Dimension):
try:
index_pos = self.dims.index(index)
return self.dims[index_pos], self.ranges[index_pos]
return NamedRange(dim=self.dims[index_pos], unit_range=self.ranges[index_pos])
except ValueError as ex:
raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex
else:
Expand All @@ -470,10 +468,12 @@ def __and__(self, other: Domain) -> Domain:
>>> I = Dimension("I")
>>> J = Dimension("J")
>>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6)))
>>> Domain(NamedRange(I, UnitRange(-1, 3))) & Domain(NamedRange(I, UnitRange(1, 6)))
Domain(dims=(Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>),), ranges=(UnitRange(1, 3),))
>>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6)))
>>> Domain(NamedRange(I, UnitRange(-1, 3)), NamedRange(J, UnitRange(2, 4))) & Domain(
... NamedRange(I, UnitRange(1, 6))
... )
Domain(dims=(Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)), ranges=(UnitRange(1, 3), UnitRange(2, 4)))
"""
broadcast_dims = tuple(promote_dims(self.dims, other.dims))
Expand All @@ -487,7 +487,7 @@ def __and__(self, other: Domain) -> Domain:
return Domain(dims=broadcast_dims, ranges=intersected_ranges)

def __str__(self) -> str:
return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})"
return f"Domain({', '.join(f'{e}' for e in self)})"

def dim_index(self, dim: Dimension) -> Optional[int]:
return self.dims.index(dim) if dim in self.dims else None
Expand All @@ -503,7 +503,7 @@ def insert(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
return self.replace(index, *named_ranges)

def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
assert all(is_named_range(nr) for nr in named_ranges)
assert all(isinstance(nr, NamedRange) for nr in named_ranges)
if isinstance(index, Dimension):
dim_index = self.dim_index(index)
if dim_index is None:
Expand All @@ -515,9 +515,10 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
)
if index < 0:
index += len(self.dims)
new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ())
dims = self.dims[:index] + new_dims + self.dims[index + 1 :]
ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :]
new_dims = (arg.dim for arg in named_ranges) if len(named_ranges) > 0 else ()
new_ranges = (arg.unit_range for arg in named_ranges) if len(named_ranges) > 0 else ()
dims = self.dims[:index] + tuple(new_dims) + self.dims[index + 1 :]
ranges = self.ranges[:index] + tuple(new_ranges) + self.ranges[index + 1 :]

return Domain(dims=dims, ranges=ranges)

Expand Down Expand Up @@ -559,10 +560,7 @@ def domain(domain_like: DomainLike) -> Domain:
if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()):
return Domain(
dims=tuple(domain_like.keys()),
ranges=tuple(
UnitRange(0, s) # type: ignore[arg-type] # type of `s` is checked in condition
for s in domain_like.values()
),
ranges=tuple(UnitRange(0, s) for s in domain_like.values()),
)
return Domain(
dims=tuple(domain_like.keys()),
Expand Down Expand Up @@ -949,15 +947,15 @@ def from_offset(

def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]:
if not isinstance(image_range, UnitRange):
if image_range[0] != self.codomain:
if image_range.dim != self.codomain:
raise ValueError(
f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'."
f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'."
)

image_range = image_range[1]
image_range = image_range.unit_range

assert isinstance(image_range, UnitRange)
return ((self.codomain, image_range - self.offset),)
return (named_range((self.codomain, image_range - self.offset)),)

def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField:
raise NotImplementedError()
Expand Down
42 changes: 19 additions & 23 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _relative_sub_domain(
if isinstance(idx, slice):
try:
sliced = _slice_range(rng, idx)
named_ranges.append((dim, sliced))
named_ranges.append(common.NamedRange(dim, sliced))
except IndexError as ex:
raise embedded_exceptions.IndexOutOfBounds(
domain=domain, indices=index, index=idx, dim=dim
Expand All @@ -76,14 +76,14 @@ def _absolute_sub_domain(
for i, (dim, rng) in enumerate(domain):
if (pos := _find_index_of_dim(dim, index)) is not None:
named_idx = index[pos]
idx = named_idx[1]
_, idx = named_idx
if isinstance(idx, common.UnitRange):
if not idx <= rng:
raise embedded_exceptions.IndexOutOfBounds(
domain=domain, indices=index, index=named_idx, dim=dim
)

named_ranges.append((dim, idx))
named_ranges.append(common.NamedRange(dim, idx))
else:
# not in new domain
assert common.is_int_index(idx)
Expand All @@ -93,7 +93,7 @@ def _absolute_sub_domain(
)
else:
# dimension not mentioned in slice
named_ranges.append((dim, domain.ranges[i]))
named_ranges.append(common.NamedRange(dim, domain.ranges[i]))

return common.Domain(*named_ranges)

Expand Down Expand Up @@ -137,23 +137,23 @@ def restrict_to_intersection(
"""
ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,)
intersection_without_ignore_dims = domain_intersection(*[
common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple])
common.Domain(*[nr for nr in domain if nr.dim not in ignore_dims_tuple])
for domain in domains
])
return tuple(
common.Domain(*[
(d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1])
for d, r in domain
(nr if nr.dim in ignore_dims_tuple else intersection_without_ignore_dims[nr.dim])
for nr in domain
])
for domain in domains
)


def iterate_domain(
domain: common.Domain,
) -> Iterator[tuple[tuple[common.Dimension, int]]]:
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]`
) -> Iterator[tuple[common.NamedIndex]]:
for idx in itertools.product(*(list(r) for r in domain.ranges)):
yield tuple(common.NamedIndex(d, i) for d, i in zip(domain.dims, idx)) # type: ignore[misc] # trust me, `idx` is `tuple[int, ...]`


def _expand_ellipsis(
Expand All @@ -169,7 +169,7 @@ def _expand_ellipsis(

def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange:
if slice_obj == slice(None):
return common.UnitRange(input_range.start, input_range.stop)
return input_range

start = (
input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop
Expand Down Expand Up @@ -209,20 +209,16 @@ def _named_slice_to_named_range(
) -> common.NamedRange | common.NamedSlice:
assert hasattr(idx, "start") and hasattr(idx, "stop")
if common.is_named_slice(idx):
idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = (
idx.start[0], # type: ignore[attr-defined]
idx.start[1], # type: ignore[attr-defined]
idx.stop[0], # type: ignore[attr-defined]
idx.stop[1], # type: ignore[attr-defined]
)
if idx_start_0 != idx_stop_0:
start_dim, start_value = idx.start
stop_dim, stop_value = idx.stop
if start_dim != stop_dim:
raise IndexError(
f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'."
f"Dimensions slicing mismatch between '{start_dim.value}' and '{stop_dim.value}'."
)
assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int)
return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1))
if common.is_named_index(idx.start) and idx.stop is None:
assert isinstance(start_value, int) and isinstance(stop_value, int)
return common.NamedRange(start_dim, common.UnitRange(start_value, stop_value))
if isinstance(idx.start, common.NamedIndex) and idx.stop is None:
raise IndexError(f"Upper bound needs to be specified for {idx}.")
if common.is_named_index(idx.stop) and idx.start is None:
if isinstance(idx.stop, common.NamedIndex) and idx.start is None:
raise IndexError(f"Lower bound needs to be specified for {idx}.")
return idx
Loading

0 comments on commit e344afd

Please sign in to comment.