Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: NamedRange/NamedIndex tuple to NamedTuple #1490

Merged
merged 25 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 60 additions & 62 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ClassVar,
Final,
Generic,
NamedTuple,
Never,
Optional,
ParamSpec,
Expand Down Expand Up @@ -84,7 +85,7 @@ def __str__(self) -> str:
return f"{self.value}[{self.kind}]"

def __call__(self, val: int) -> NamedIndex:
return self, val
return NamedIndex(self, val)


class Infinity(enum.Enum):
Expand Down Expand Up @@ -248,9 +249,16 @@ def __str__(self) -> str:

FiniteUnitRange: TypeAlias = UnitRange[int, int]

_Rng = TypeVar(
"_Rng",
FiniteUnitRange,
UnitRange[Infinity, int],
UnitRange[int, Infinity],
UnitRange[Infinity, Infinity],
)

RangeLike: TypeAlias = (
UnitRange
_Rng
| range
| tuple[core_defs.IntegralScalar, core_defs.IntegralScalar]
| core_defs.IntegralScalar
Expand Down Expand Up @@ -282,10 +290,26 @@ def unit_range(r: RangeLike) -> UnitRange:
raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.")


class NamedRange(NamedTuple, Generic[_Rng]):
dim: Dimension
unit_range: _Rng

def __str__(self) -> str:
return f"{self.dim}={self.unit_range}"


IntIndex: TypeAlias = int | core_defs.IntegralScalar
NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple


class NamedIndex(NamedTuple):
dim: Dimension
value: IntIndex

def __str__(self) -> str:
return f"{self.dim}={self.value}"


FiniteNamedRange: TypeAlias = NamedRange[FiniteUnitRange]
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice
Expand All @@ -304,41 +328,22 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]:
return isinstance(p, (int, core_defs.INTEGRAL_TYPES))


def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
return (
isinstance(v, tuple)
and len(v) == 2
and isinstance(v[0], Dimension)
and isinstance(v[1], UnitRange)
)


def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]:
return UnitRange.is_finite(v[1])
return UnitRange.is_finite(v.unit_range)


def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
return (
isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1])
def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[slice]:
return isinstance(obj, slice) and (
isinstance(obj.start, NamedIndex) and isinstance(obj.stop, NamedIndex)
)


def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]:
return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop))


def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]:
return (
is_int_index(v)
or is_named_range(v)
or is_named_index(v)
or isinstance(v, slice)
or v is Ellipsis
)
return is_int_index(v) or isinstance(v, (NamedRange, NamedIndex, slice)) or v is Ellipsis


def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]:
return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v)
return isinstance(v, Sequence) and all(isinstance(e, (NamedRange, NamedIndex)) for e in v)


def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]:
Expand All @@ -356,28 +361,21 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence:


def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange:
return (v[0], unit_range(v[1]))


_Rng = TypeVar(
"_Rng",
UnitRange[int, int],
UnitRange[Infinity, int],
UnitRange[int, Infinity],
UnitRange[Infinity, Infinity],
)
if isinstance(v, NamedRange):
return v
return NamedRange(v[0], unit_range(v[1]))


@dataclasses.dataclass(frozen=True, init=False)
class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]):
class Domain(Sequence[NamedRange[_Rng]], Generic[_Rng]):
"""Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s."""

dims: tuple[Dimension, ...]
ranges: tuple[_Rng, ...]

def __init__(
self,
*args: tuple[Dimension, _Rng],
*args: NamedRange[_Rng],
dims: Optional[Sequence[Dimension]] = None,
ranges: Optional[Sequence[_Rng]] = None,
) -> None:
Expand Down Expand Up @@ -406,7 +404,7 @@ def __init__(
object.__setattr__(self, "dims", tuple(dims))
object.__setattr__(self, "ranges", tuple(ranges))
else:
if not all(is_named_range(arg) for arg in args):
if not all(isinstance(arg, NamedRange) for arg in args):
raise ValueError(
f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'."
)
Expand Down Expand Up @@ -437,25 +435,25 @@ def is_empty(self) -> bool:
return any(rng.is_empty() for rng in self.ranges)

@overload
def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ...
def __getitem__(self, index: int) -> NamedRange: ...

@overload
def __getitem__(self, index: slice) -> Self: ...

@overload
def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ...
def __getitem__(self, index: Dimension) -> NamedRange: ...

def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain:
if isinstance(index, int):
return self.dims[index], self.ranges[index]
return NamedRange(dim=self.dims[index], unit_range=self.ranges[index])
elif isinstance(index, slice):
dims_slice = self.dims[index]
ranges_slice = self.ranges[index]
return Domain(dims=dims_slice, ranges=ranges_slice)
elif isinstance(index, Dimension):
try:
index_pos = self.dims.index(index)
return self.dims[index_pos], self.ranges[index_pos]
return NamedRange(dim=self.dims[index_pos], unit_range=self.ranges[index_pos])
except ValueError as ex:
raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex
else:
Expand All @@ -470,10 +468,12 @@ def __and__(self, other: Domain) -> Domain:
>>> I = Dimension("I")
>>> J = Dimension("J")

>>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6)))
>>> Domain(NamedRange(I, UnitRange(-1, 3))) & Domain(NamedRange(I, UnitRange(1, 6)))
Domain(dims=(Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>),), ranges=(UnitRange(1, 3),))

>>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6)))
>>> Domain(NamedRange(I, UnitRange(-1, 3)), NamedRange(J, UnitRange(2, 4))) & Domain(
... NamedRange(I, UnitRange(1, 6))
... )
Domain(dims=(Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)), ranges=(UnitRange(1, 3), UnitRange(2, 4)))
"""
broadcast_dims = tuple(promote_dims(self.dims, other.dims))
Expand All @@ -487,7 +487,7 @@ def __and__(self, other: Domain) -> Domain:
return Domain(dims=broadcast_dims, ranges=intersected_ranges)

def __str__(self) -> str:
return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})"
return f"Domain({', '.join(f'{e}' for e in self)})"

def dim_index(self, dim: Dimension) -> Optional[int]:
return self.dims.index(dim) if dim in self.dims else None
Expand All @@ -503,7 +503,7 @@ def insert(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
return self.replace(index, *named_ranges)

def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
assert all(is_named_range(nr) for nr in named_ranges)
assert all(isinstance(nr, NamedRange) for nr in named_ranges)
if isinstance(index, Dimension):
dim_index = self.dim_index(index)
if dim_index is None:
Expand All @@ -515,9 +515,10 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
)
if index < 0:
index += len(self.dims)
new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ())
dims = self.dims[:index] + new_dims + self.dims[index + 1 :]
ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :]
new_dims = (arg.dim for arg in named_ranges) if len(named_ranges) > 0 else ()
new_ranges = (arg.unit_range for arg in named_ranges) if len(named_ranges) > 0 else ()
dims = self.dims[:index] + tuple(new_dims) + self.dims[index + 1 :]
ranges = self.ranges[:index] + tuple(new_ranges) + self.ranges[index + 1 :]

return Domain(dims=dims, ranges=ranges)

Expand Down Expand Up @@ -559,10 +560,7 @@ def domain(domain_like: DomainLike) -> Domain:
if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()):
return Domain(
dims=tuple(domain_like.keys()),
ranges=tuple(
UnitRange(0, s) # type: ignore[arg-type] # type of `s` is checked in condition
for s in domain_like.values()
),
ranges=tuple(UnitRange(0, s) for s in domain_like.values()),
)
return Domain(
dims=tuple(domain_like.keys()),
Expand Down Expand Up @@ -949,15 +947,15 @@ def from_offset(

def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]:
if not isinstance(image_range, UnitRange):
if image_range[0] != self.codomain:
if image_range.dim != self.codomain:
raise ValueError(
f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'."
f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'."
)

image_range = image_range[1]
image_range = image_range.unit_range

assert isinstance(image_range, UnitRange)
return ((self.codomain, image_range - self.offset),)
return (named_range((self.codomain, image_range - self.offset)),)

def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField:
raise NotImplementedError()
Expand Down
42 changes: 19 additions & 23 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _relative_sub_domain(
if isinstance(idx, slice):
try:
sliced = _slice_range(rng, idx)
named_ranges.append((dim, sliced))
named_ranges.append(common.NamedRange(dim, sliced))
except IndexError as ex:
raise embedded_exceptions.IndexOutOfBounds(
domain=domain, indices=index, index=idx, dim=dim
Expand All @@ -76,14 +76,14 @@ def _absolute_sub_domain(
for i, (dim, rng) in enumerate(domain):
if (pos := _find_index_of_dim(dim, index)) is not None:
named_idx = index[pos]
idx = named_idx[1]
_, idx = named_idx
if isinstance(idx, common.UnitRange):
if not idx <= rng:
raise embedded_exceptions.IndexOutOfBounds(
domain=domain, indices=index, index=named_idx, dim=dim
)

named_ranges.append((dim, idx))
named_ranges.append(common.NamedRange(dim, idx))
else:
# not in new domain
assert common.is_int_index(idx)
Expand All @@ -93,7 +93,7 @@ def _absolute_sub_domain(
)
else:
# dimension not mentioned in slice
named_ranges.append((dim, domain.ranges[i]))
named_ranges.append(common.NamedRange(dim, domain.ranges[i]))

return common.Domain(*named_ranges)

Expand Down Expand Up @@ -137,23 +137,23 @@ def restrict_to_intersection(
"""
ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,)
intersection_without_ignore_dims = domain_intersection(*[
common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple])
common.Domain(*[nr for nr in domain if nr.dim not in ignore_dims_tuple])
for domain in domains
])
return tuple(
common.Domain(*[
(d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1])
for d, r in domain
(nr if nr.dim in ignore_dims_tuple else intersection_without_ignore_dims[nr.dim])
for nr in domain
])
for domain in domains
)


def iterate_domain(
domain: common.Domain,
) -> Iterator[tuple[tuple[common.Dimension, int]]]:
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]`
) -> Iterator[tuple[common.NamedIndex]]:
for idx in itertools.product(*(list(r) for r in domain.ranges)):
yield tuple(common.NamedIndex(d, i) for d, i in zip(domain.dims, idx)) # type: ignore[misc] # trust me, `idx` is `tuple[int, ...]`


def _expand_ellipsis(
Expand All @@ -169,7 +169,7 @@ def _expand_ellipsis(

def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange:
if slice_obj == slice(None):
return common.UnitRange(input_range.start, input_range.stop)
return input_range

start = (
input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop
Expand Down Expand Up @@ -209,20 +209,16 @@ def _named_slice_to_named_range(
) -> common.NamedRange | common.NamedSlice:
assert hasattr(idx, "start") and hasattr(idx, "stop")
if common.is_named_slice(idx):
idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = (
idx.start[0], # type: ignore[attr-defined]
idx.start[1], # type: ignore[attr-defined]
idx.stop[0], # type: ignore[attr-defined]
idx.stop[1], # type: ignore[attr-defined]
)
if idx_start_0 != idx_stop_0:
start_dim, start_value = idx.start
stop_dim, stop_value = idx.stop
if start_dim != stop_dim:
raise IndexError(
f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'."
f"Dimensions slicing mismatch between '{start_dim.value}' and '{stop_dim.value}'."
)
assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int)
return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1))
if common.is_named_index(idx.start) and idx.stop is None:
assert isinstance(start_value, int) and isinstance(stop_value, int)
return common.NamedRange(start_dim, common.UnitRange(start_value, stop_value))
if isinstance(idx.start, common.NamedIndex) and idx.stop is None:
raise IndexError(f"Upper bound needs to be specified for {idx}.")
if common.is_named_index(idx.stop) and idx.start is None:
if isinstance(idx.stop, common.NamedIndex) and idx.start is None:
raise IndexError(f"Lower bound needs to be specified for {idx}.")
return idx
Loading
Loading