From fd7bd24cd28e3a91a7d0dfecdaff4ae98ed03bbe Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 13 Mar 2024 15:24:58 +0100 Subject: [PATCH 01/23] edits for NamedRange class --- src/gt4py/next/common.py | 26 +++++++++++++++++--------- src/gt4py/next/iterator/embedded.py | 14 ++++++++------ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 0aa19b20ae..43704b7931 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -277,9 +277,17 @@ def unit_range(r: RangeLike) -> UnitRange: raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") +@dataclasses.dataclass(frozen=True) +class NamedRange: + dims: Dimension + urange: UnitRange + + def __str__(self) -> str: + return f"{self.dims}={self.urange}" + + IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple -NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ @@ -309,7 +317,7 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: - return UnitRange.is_finite(v[1]) + return UnitRange.is_finite(v.urange) def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: @@ -351,7 +359,7 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: - return (v[0], unit_range(v[1])) + return NamedRange(dims=v[0], urange=unit_range(v[1])) _Rng = TypeVar( @@ -439,7 +447,7 @@ def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ... def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: if isinstance(index, int): - return self.dims[index], self.ranges[index] + return named_range((self.dims[index], self.ranges[index])) elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] @@ -447,7 +455,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) - return self.dims[index_pos], self.ranges[index_pos] + return named_range((self.dims[index_pos], self.ranges[index_pos])) except ValueError as ex: raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex else: @@ -957,15 +965,15 @@ def from_offset( def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: if not isinstance(image_range, UnitRange): - if image_range[0] != self.codomain: + if image_range.dims != self.codomain: raise ValueError( - f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." + f"Dimension '{image_range.dims}' does not match the codomain dimension '{self.codomain}'." ) - image_range = image_range[1] + image_range = image_range.urange assert isinstance(image_range, UnitRange) - return ((self.codomain, image_range - self.offset),) + return (named_range((self.codomain, image_range - self.offset)),) def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f9f1ba47e0..b66f577afb 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -205,7 +205,9 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 column_range: common.NamedRange = column_range_cvar.get() - self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data) + self.data = ( + data if isinstance(data, np.ndarray) else np.full(len(column_range.urange), data) + ) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -746,7 +748,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange assert column_range is not None col: list[ @@ -823,7 +825,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.urange)) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -864,7 +866,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1090,7 +1092,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - def restrict(self, item: common.AnyIndexSpec) -> common.Field: if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off - d, r = item[0] + d, r = item.dims assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work @@ -1489,7 +1491,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") From b48c98af992759405b0f96000c1fbd5163e9fa35 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 13 Mar 2024 16:27:37 +0100 Subject: [PATCH 02/23] edits for NamedRange class --- src/gt4py/next/common.py | 21 +++++++------- src/gt4py/next/embedded/common.py | 34 +++++++++++------------ src/gt4py/next/embedded/nd_array_field.py | 14 ++++++---- src/gt4py/next/iterator/embedded.py | 8 +++--- 4 files changed, 40 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 43704b7931..0b28d688da 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -279,11 +279,11 @@ def unit_range(r: RangeLike) -> UnitRange: @dataclasses.dataclass(frozen=True) class NamedRange: - dims: Dimension + dim: Dimension urange: UnitRange def __str__(self) -> str: - return f"{self.dims}={self.urange}" + return f"{self.dim}={self.urange}" IntIndex: TypeAlias = int | core_defs.IntegralScalar @@ -309,10 +309,9 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]: def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( - isinstance(v, tuple) - and len(v) == 2 - and isinstance(v[0], Dimension) - and isinstance(v[1], UnitRange) + isinstance(v, NamedRange) + and isinstance(v.dim, Dimension) + and isinstance(v.urange, UnitRange) ) @@ -359,7 +358,7 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: - return NamedRange(dims=v[0], urange=unit_range(v[1])) + return NamedRange(dim=v[0], urange=unit_range(v[1])) _Rng = TypeVar( @@ -413,7 +412,9 @@ def __init__( raise ValueError( f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." ) - dims, ranges = zip(*args) if args else ((), ()) + dims = (args[i].dim for i in range(len(args))) if args else () + ranges = (args[i].urange for i in range(len(args))) if args else () + # dims, ranges = zip(*args) if args else ((), ()) object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) @@ -965,9 +966,9 @@ def from_offset( def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: if not isinstance(image_range, UnitRange): - if image_range.dims != self.codomain: + if image_range.dim != self.codomain: raise ValueError( - f"Dimension '{image_range.dims}' does not match the codomain dimension '{self.codomain}'." + f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." ) image_range = image_range.urange diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 0ba39d377a..1c0a93e91b 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -46,23 +46,23 @@ def _relative_sub_domain( f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions." ) expanded += (slice(None),) * (len(domain) - len(expanded)) - for (dim, rng), idx in zip(domain, expanded, strict=True): + for dom, idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): try: - sliced = _slice_range(rng, idx) - named_ranges.append((dim, sliced)) + sliced = _slice_range(dom.urange, idx) + named_ranges.append(common.named_range((dom.dim, sliced))) except IndexError as ex: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=idx, dim=dim + domain=domain, indices=index, index=idx, dim=dom.dim ) from ex else: # not in new domain assert common.is_int_index(idx) - assert common.UnitRange.is_finite(rng) - new_index = (rng.start if idx >= 0 else rng.stop) + idx - if new_index < rng.start or new_index >= rng.stop: + assert common.UnitRange.is_finite(dom.urange) + new_index = (dom.urange.start if idx >= 0 else dom.urange.stop) + idx + if new_index < dom.urange.start or new_index >= dom.urange.stop: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=idx, dim=dim + domain=domain, indices=index, index=idx, dim=dom.dim ) return common.Domain(*named_ranges) @@ -71,28 +71,28 @@ def _relative_sub_domain( def _absolute_sub_domain( domain: common.Domain, index: common.AbsoluteIndexSequence ) -> common.Domain: - named_ranges: list[common.NamedRange] = [] - for i, (dim, rng) in enumerate(domain): - if (pos := _find_index_of_dim(dim, index)) is not None: + named_ranges: list[tuple[common.Dimension, Any]] = [] + for i in range(domain.ndim): + if (pos := _find_index_of_dim(domain.dims[i], index)) is not None: named_idx = index[pos] idx = named_idx[1] if isinstance(idx, common.UnitRange): - if not idx <= rng: + if not idx <= domain.ranges[i]: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=named_idx, dim=dim + domain=domain, indices=index, index=named_idx, dim=domain.dims[i] ) - named_ranges.append((dim, idx)) + named_ranges.append((domain.dims[i], idx)) else: # not in new domain assert common.is_int_index(idx) - if idx < rng.start or idx >= rng.stop: + if idx < domain.ranges[i].start or idx >= domain.ranges[i].stop: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=named_idx, dim=dim + domain=domain, indices=index, index=named_idx, dim=domain.dims[i] ) else: # dimension not mentioned in slice - named_ranges.append((dim, domain.ranges[i])) + named_ranges.append((domain.dims[i], domain.ranges[i])) return common.Domain(*named_ranges) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1760cb17e8..99378e2b8c 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -116,7 +116,7 @@ def shape(self) -> tuple[int, ...]: @property def __gt_origin__(self) -> tuple[int, ...]: assert common.Domain.is_finite(self._domain) - return tuple(-r.start for _, r in self._domain) + return tuple(-r.start for r in self._domain.ranges) @property def ndarray(self) -> core_defs.NDArrayObject: @@ -407,12 +407,12 @@ def inverse_image( if not isinstance( image_range, common.UnitRange ): # TODO(havogt): cleanup duplication with CartesianConnectivity - if image_range[0] != self.codomain: + if image_range.dim != self.codomain: raise ValueError( - f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." + f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." ) - image_range = image_range[1] + image_range = image_range.urange assert isinstance(image_range, common.UnitRange) @@ -690,8 +690,10 @@ def _get_slices_from_domain_slice( """ slice_indices: list[slice | common.IntIndex] = [] - for pos_old, (dim, _) in enumerate(domain): - if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: + for pos_old in range(domain.ndim): + if ( + pos := embedded_common._find_index_of_dim(domain.dims[pos_old], domain_slice) + ) is not None: index_or_range = domain_slice[pos][1] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b66f577afb..202be2b10c 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -748,7 +748,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get().urange + column_range = column_range_cvar.get()[1] assert column_range is not None col: list[ @@ -825,7 +825,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.urange)) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -866,7 +866,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get().urange + column_range = column_range_cvar.get()[1] # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1491,7 +1491,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get().urange + column_range = column_range_cvar.get()[1] if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") From 84c28a9a1d6b353004e8f7401129ad24db742ca4 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 13 Mar 2024 17:20:16 +0100 Subject: [PATCH 03/23] further edits --- src/gt4py/next/common.py | 11 +++++------ src/gt4py/next/embedded/common.py | 1 + src/gt4py/next/iterator/embedded.py | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 0b28d688da..0f6138213b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -379,7 +379,7 @@ class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): def __init__( self, - *args: tuple[Dimension, _Rng], + *args: NamedRange, dims: Optional[Sequence[Dimension]] = None, ranges: Optional[Sequence[_Rng]] = None, ) -> None: @@ -412,11 +412,10 @@ def __init__( raise ValueError( f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." ) - dims = (args[i].dim for i in range(len(args))) if args else () - ranges = (args[i].urange for i in range(len(args))) if args else () - # dims, ranges = zip(*args) if args else ((), ()) - object.__setattr__(self, "dims", tuple(dims)) - object.__setattr__(self, "ranges", tuple(ranges)) + dims_new = (arg.dim for arg in args) if args else () + ranges_new = (arg.urange for arg in args) if args else () + object.__setattr__(self, "dims", tuple(dims_new)) + object.__setattr__(self, "ranges", tuple(ranges_new)) if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not '{self.dims}'.") diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 1c0a93e91b..1cbf5d6e42 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -47,6 +47,7 @@ def _relative_sub_domain( ) expanded += (slice(None),) * (len(domain) - len(expanded)) for dom, idx in zip(domain, expanded, strict=True): + assert isinstance(dom, common.NamedRange) if isinstance(idx, slice): try: sliced = _slice_range(dom.urange, idx) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 202be2b10c..0c2ce27808 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1491,7 +1491,9 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get()[1] + columns = column_range_cvar.get() + assert isinstance(columns, tuple) + column_range = columns[1] if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") From 40f07fc9be7f880a5be42a3bd95bd7dd9cb72177 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 13 Mar 2024 17:23:33 +0100 Subject: [PATCH 04/23] small edit --- src/gt4py/next/iterator/embedded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 0c2ce27808..02f57f04ed 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -908,7 +908,7 @@ def _translate_named_indices( domain_slice: list[common.NamedRange | common.NamedIndex] = [] for d, v in named_indices.items(): if isinstance(v, range): - domain_slice.append((d, common.UnitRange(v.start, v.stop))) + domain_slice.append(common.named_range((d, common.UnitRange(v.start, v.stop)))) elif isinstance(v, list): assert len(v) == 1 # only 1 sparse dimension is supported assert common.is_int_index( From e77192638dc67f6b51c8d28599a3bbd136bc6f4c Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 14 Mar 2024 10:48:19 +0100 Subject: [PATCH 05/23] further changes to fix tests --- src/gt4py/next/common.py | 8 +++++--- src/gt4py/next/embedded/common.py | 13 ++++++++++--- src/gt4py/next/embedded/nd_array_field.py | 6 +++--- src/gt4py/next/embedded/operators.py | 2 +- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 0f6138213b..5d68eeb834 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -515,9 +515,11 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: ) if index < 0: index += len(self.dims) - new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) - dims = self.dims[:index] + new_dims + self.dims[index + 1 :] - ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :] + new_dims = (arg.dim for arg in named_ranges) if len(named_ranges) > 0 else () + new_ranges = (arg.urange for arg in named_ranges) if len(named_ranges) > 0 else () + # new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) + dims = self.dims[:index] + tuple(new_dims) + self.dims[index + 1 :] + ranges = self.ranges[:index] + tuple(new_ranges) + self.ranges[index + 1 :] return Domain(dims=dims, ranges=ranges) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 1cbf5d6e42..13c466913a 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -76,14 +76,17 @@ def _absolute_sub_domain( for i in range(domain.ndim): if (pos := _find_index_of_dim(domain.dims[i], index)) is not None: named_idx = index[pos] - idx = named_idx[1] + if not isinstance(named_idx, common.NamedRange): + idx = named_idx[1] + else: + idx = named_idx.urange if isinstance(idx, common.UnitRange): if not idx <= domain.ranges[i]: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=named_idx, dim=domain.dims[i] ) - named_ranges.append((domain.dims[i], idx)) + named_ranges.append(common.named_range((domain.dims[i], idx))) else: # not in new domain assert common.is_int_index(idx) @@ -93,7 +96,7 @@ def _absolute_sub_domain( ) else: # dimension not mentioned in slice - named_ranges.append((domain.dims[i], domain.ranges[i])) + named_ranges.append(common.named_range((domain.dims[i], domain.ranges[i]))) return common.Domain(*named_ranges) @@ -143,6 +146,10 @@ def _find_index_of_dim( dim: common.Dimension, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], ) -> Optional[int]: + if isinstance(domain_slice, common.Domain): + for i in range(domain_slice.ndim): + if domain_slice.dims[i] == dim: + return i for i, (d, _) in enumerate(domain_slice): if dim == d: return i diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 99378e2b8c..2ef00895f5 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -190,7 +190,7 @@ def remap( if dim_idx is None: raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") - current_range: common.UnitRange = self.domain[dim_idx][1] + current_range: common.UnitRange = self.domain.ranges[dim_idx] new_ranges = connectivity.inverse_image(current_range) new_domain = self.domain.replace(dim_idx, *new_ranges) @@ -545,7 +545,7 @@ def _builtin_op( axis.value ] # assumes offset and local dimension have same name assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) - new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) + new_domain = common.Domain(*[common.named_range((field.domain.dims[idx], field.domain.ranges[idx])) for idx, dim in enumerate(field.domain.dims) if dim != axis]) broadcast_slice = tuple( slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis @@ -694,7 +694,7 @@ def _get_slices_from_domain_slice( if ( pos := embedded_common._find_index_of_dim(domain.dims[pos_old], domain_slice) ) is not None: - index_or_range = domain_slice[pos][1] + index_or_range = domain_slice[pos][1] if isinstance(domain_slice, tuple) else domain_slice.ranges[pos] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: slice_indices.append(slice(None)) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 0982024090..2816df9542 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -125,7 +125,7 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: - vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL] + vertical_dim_filtered = [nr for nr in domain.dims if nr.kind == common.DimensionKind.VERTICAL] assert len(vertical_dim_filtered) <= 1 return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING From b6d39f7fbb1f89330f381617b17294cb7460bcb2 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 14 Mar 2024 11:26:58 +0100 Subject: [PATCH 06/23] further edits --- src/gt4py/next/common.py | 1 - src/gt4py/next/embedded/common.py | 2 +- src/gt4py/next/embedded/nd_array_field.py | 12 ++++++++++-- src/gt4py/next/embedded/operators.py | 4 ++-- src/gt4py/next/iterator/embedded.py | 16 ++++++++-------- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 5d68eeb834..4df54b2453 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -517,7 +517,6 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: index += len(self.dims) new_dims = (arg.dim for arg in named_ranges) if len(named_ranges) > 0 else () new_ranges = (arg.urange for arg in named_ranges) if len(named_ranges) > 0 else () - # new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) dims = self.dims[:index] + tuple(new_dims) + self.dims[index + 1 :] ranges = self.ranges[:index] + tuple(new_ranges) + self.ranges[index + 1 :] diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 13c466913a..e9bd8db8a5 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -72,7 +72,7 @@ def _relative_sub_domain( def _absolute_sub_domain( domain: common.Domain, index: common.AbsoluteIndexSequence ) -> common.Domain: - named_ranges: list[tuple[common.Dimension, Any]] = [] + named_ranges: list[common.NamedRange] = [] for i in range(domain.ndim): if (pos := _find_index_of_dim(domain.dims[i], index)) is not None: named_idx = index[pos] diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 2ef00895f5..d424d21004 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -545,7 +545,11 @@ def _builtin_op( axis.value ] # assumes offset and local dimension have same name assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) - new_domain = common.Domain(*[common.named_range((field.domain.dims[idx], field.domain.ranges[idx])) for idx, dim in enumerate(field.domain.dims) if dim != axis]) + new_domain = common.Domain(*[ + common.named_range((field.domain.dims[idx], field.domain.ranges[idx])) + for idx, dim in enumerate(field.domain.dims) + if dim != axis + ]) broadcast_slice = tuple( slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis @@ -694,7 +698,11 @@ def _get_slices_from_domain_slice( if ( pos := embedded_common._find_index_of_dim(domain.dims[pos_old], domain_slice) ) is not None: - index_or_range = domain_slice[pos][1] if isinstance(domain_slice, tuple) else domain_slice.ranges[pos] + index_or_range = ( + domain_slice[pos][1] + if isinstance(domain_slice, tuple) + else domain_slice.ranges[pos] + ) slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: slice_indices.append(slice(None)) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 2816df9542..a21ada02cb 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -48,8 +48,8 @@ def __call__( # type: ignore[override] **kwargs: common.Field | core_defs.Scalar, # type: ignore[override] ) -> common.Field: scan_range = embedded_context.closure_column_range.get() - assert self.axis == scan_range[0] - scan_axis = scan_range[0] + assert self.axis == scan_range.dim + scan_axis = scan_range.dim all_args = [*args, *kwargs.values()] domain_intersection = _intersect_scan_args(*all_args) non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 02f57f04ed..f77427ade7 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -748,7 +748,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange assert column_range is not None col: list[ @@ -825,7 +825,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.urange)) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -866,7 +866,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1059,7 +1059,7 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: if self._cur_index is None: - return common.Domain((self._dimension, common.UnitRange.infinite())) + return common.Domain(common.named_range((self._dimension, common.UnitRange.infinite()))) else: return common.Domain() @@ -1492,8 +1492,8 @@ def _column_dtype(elem: Any) -> np.dtype: def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): columns = column_range_cvar.get() - assert isinstance(columns, tuple) - column_range = columns[1] + assert isinstance(columns, common.NamedRange) + column_range = columns.urange if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") @@ -1546,10 +1546,10 @@ def closure( column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] - column_range = ( + column_range = common.named_range(( column_axis, common.UnitRange(column.col_range.start, column.col_range.stop), - ) + )) out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) From 6ed4cfdb9ec3fd74d96bacb2c17a63c796a9c170 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Thu, 14 Mar 2024 12:20:40 +0100 Subject: [PATCH 07/23] Update src/gt4py/next/common.py Co-authored-by: Hannes Vogt --- src/gt4py/next/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 4df54b2453..b48926f23b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -278,9 +278,9 @@ def unit_range(r: RangeLike) -> UnitRange: @dataclasses.dataclass(frozen=True) -class NamedRange: +class NamedRange(Generic[_Rng]): dim: Dimension - urange: UnitRange + urange: _Rng def __str__(self) -> str: return f"{self.dim}={self.urange}" From 8909221151d0931b49228c6f74bc55d52ccdfcf8 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 14 Mar 2024 13:30:53 +0100 Subject: [PATCH 08/23] further edits --- src/gt4py/next/common.py | 27 ++++++++++----------- src/gt4py/next/embedded/common.py | 2 +- src/gt4py/next/embedded/nd_array_field.py | 6 ++--- src/gt4py/next/embedded/operators.py | 29 +++++++++++++++++++---- src/gt4py/next/iterator/embedded.py | 5 ++-- 5 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index b48926f23b..fa1edc2efb 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -252,6 +252,14 @@ def __str__(self) -> str: | None ) +_Rng = TypeVar( + "_Rng", + UnitRange[int, int], + UnitRange[Infinity, int], + UnitRange[int, Infinity], + UnitRange[Infinity, Infinity], +) + def unit_range(r: RangeLike) -> UnitRange: if isinstance(r, UnitRange): @@ -288,7 +296,7 @@ def __str__(self) -> str: IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple -FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple +FiniteNamedRange: TypeAlias = NamedRange[FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice @@ -319,7 +327,7 @@ def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: return UnitRange.is_finite(v.urange) -def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: +def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedIndex]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) ) @@ -361,15 +369,6 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: return NamedRange(dim=v[0], urange=unit_range(v[1])) -_Rng = TypeVar( - "_Rng", - UnitRange[int, int], - UnitRange[Infinity, int], - UnitRange[int, Infinity], - UnitRange[Infinity, Infinity], -) - - @dataclasses.dataclass(frozen=True, init=False) class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" @@ -445,9 +444,9 @@ def __getitem__(self, index: slice) -> Self: ... @overload def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ... - def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: + def __getitem__(self, index: int | slice | Dimension) -> tuple[Dimension, _Rng] | Domain: if isinstance(index, int): - return named_range((self.dims[index], self.ranges[index])) + return (self.dims[index], self.ranges[index]) elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] @@ -455,7 +454,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) - return named_range((self.dims[index_pos], self.ranges[index_pos])) + return (self.dims[index_pos], self.ranges[index_pos]) except ValueError as ex: raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex else: diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index e9bd8db8a5..00c56dc33e 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -182,7 +182,7 @@ def _named_slice_to_named_range( f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'." ) assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int) - return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1)) + return common.named_range((idx_start_0, common.UnitRange(idx_start_1, idx_stop_1))) if common.is_named_index(idx.start) and idx.stop is None: raise IndexError(f"Upper bound needs to be specified for {idx}.") if common.is_named_index(idx.stop) and idx.start is None: diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index d424d21004..ac38188ea3 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -647,10 +647,10 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] for dim in new_dimensions: if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) - named_ranges.append((dim, field.domain[pos][1])) + named_ranges.append(common.named_range((dim, field.domain.ranges[pos]))) else: domain_slice.append(np.newaxis) - named_ranges.append((dim, common.UnitRange.infinite())) + named_ranges.append(common.named_range((dim, common.UnitRange.infinite()))) return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -676,7 +676,7 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA def _get_slices_from_domain_slice( domain: common.Domain, - domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], + domain_slice: common.Domain | Sequence[Any], ) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index a21ada02cb..ae2407d871 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -52,10 +52,23 @@ def __call__( # type: ignore[override] scan_axis = scan_range.dim all_args = [*args, *kwargs.values()] domain_intersection = _intersect_scan_args(*all_args) - non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) + non_scan_domain = common.Domain(*[ + common.named_range(( + domain_intersection.dims[idx_nr], + domain_intersection.ranges[idx_nr], + )) + for idx_nr, nr in enumerate(domain_intersection.dims) + if nr != scan_axis + ]) out_domain = common.Domain(*[ - scan_range if nr[0] == scan_axis else nr for nr in domain_intersection + scan_range + if nr == scan_axis + else common.named_range(( + domain_intersection.dims[idx_nr], + domain_intersection.ranges[idx_nr], + )) + for idx_nr, nr in enumerate(domain_intersection.dims) ]) if scan_axis not in out_domain.dims: # even if the scan dimension is not in the input, we can scan over it @@ -66,7 +79,7 @@ def __call__( # type: ignore[override] def scan_loop(hpos): acc = self.init - for k in scan_range[1] if self.forward else reversed(scan_range[1]): + for k in scan_range.urange if self.forward else reversed(scan_range.urange): pos = (*hpos, (scan_axis, k)) new_args = [_tuple_at(pos, arg) for arg in args] new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} @@ -124,8 +137,14 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): return op(*args, **kwargs) -def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: - vertical_dim_filtered = [nr for nr in domain.dims if nr.kind == common.DimensionKind.VERTICAL] +def _get_vertical_range( + domain: common.Domain, +) -> common.NamedRange | eve.NothingType: + vertical_dim_filtered = [ + common.named_range((domain.dims[idx_nr], domain.ranges[idx_nr])) + for idx_nr, nr in enumerate(domain.dims) + if nr.kind == common.DimensionKind.VERTICAL + ] assert len(vertical_dim_filtered) <= 1 return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f77427ade7..f73c1a8878 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1092,10 +1092,11 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - def restrict(self, item: common.AnyIndexSpec) -> common.Field: if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off - d, r = item.dims + assert common.is_named_index(item[0]) # for mypy errors on multiple lines below + d, r = item[0] assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) - return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work + return self.__class__(self._dimension, r) # TODO set a domain... raise NotImplementedError() From bf96eda7918ff7ada68fa43b8a2f1dde5e31cdb6 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 14 Mar 2024 14:01:20 +0100 Subject: [PATCH 09/23] further edits --- src/gt4py/next/embedded/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 00c56dc33e..8ad4f0e860 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -47,9 +47,9 @@ def _relative_sub_domain( ) expanded += (slice(None),) * (len(domain) - len(expanded)) for dom, idx in zip(domain, expanded, strict=True): - assert isinstance(dom, common.NamedRange) if isinstance(idx, slice): try: + assert common.is_named_range(dom) sliced = _slice_range(dom.urange, idx) named_ranges.append(common.named_range((dom.dim, sliced))) except IndexError as ex: @@ -59,9 +59,9 @@ def _relative_sub_domain( else: # not in new domain assert common.is_int_index(idx) - assert common.UnitRange.is_finite(dom.urange) - new_index = (dom.urange.start if idx >= 0 else dom.urange.stop) + idx - if new_index < dom.urange.start or new_index >= dom.urange.stop: + assert common.is_named_range(dom) and common.is_finite_named_range(dom.urange) + new_index = (dom.urange.start if idx >= 0 else dom.urange.stop) + idx # type: ignore[attr-defined] # urange attr checked in assert above + if new_index < dom.urange.start or new_index >= dom.urange.stop: # type: ignore[attr-defined] # urange attr checked in assert above raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=idx, dim=dom.dim ) @@ -144,7 +144,7 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit def _find_index_of_dim( dim: common.Dimension, - domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], + domain_slice: common.Domain | Sequence[common.NamedIndex | Any], ) -> Optional[int]: if isinstance(domain_slice, common.Domain): for i in range(domain_slice.ndim): From 8fdaefc79f2407ea6a40655c5a270716fb0eef2a Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 14 Mar 2024 15:14:00 +0100 Subject: [PATCH 10/23] edits --- src/gt4py/next/embedded/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 8ad4f0e860..f37b30e870 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -47,9 +47,9 @@ def _relative_sub_domain( ) expanded += (slice(None),) * (len(domain) - len(expanded)) for dom, idx in zip(domain, expanded, strict=True): + dom = common.named_range(dom) if isinstance(idx, slice): try: - assert common.is_named_range(dom) sliced = _slice_range(dom.urange, idx) named_ranges.append(common.named_range((dom.dim, sliced))) except IndexError as ex: @@ -127,7 +127,7 @@ def _expand_ellipsis( def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: if slice_obj == slice(None): - return common.UnitRange(input_range.start, input_range.stop) + return input_range start = ( input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop From 7abeb4bdc55391726b099fc8e7c77fbbac7106f2 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 14 Mar 2024 16:40:45 +0100 Subject: [PATCH 11/23] some edits --- src/gt4py/next/embedded/common.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index f37b30e870..bf2345e5fb 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -47,23 +47,26 @@ def _relative_sub_domain( ) expanded += (slice(None),) * (len(domain) - len(expanded)) for dom, idx in zip(domain, expanded, strict=True): - dom = common.named_range(dom) + new_dom = common.named_range(dom) if isinstance(idx, slice): try: - sliced = _slice_range(dom.urange, idx) - named_ranges.append(common.named_range((dom.dim, sliced))) + sliced = _slice_range(new_dom.urange, idx) + named_ranges.append(common.named_range((new_dom.dim, sliced))) except IndexError as ex: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=idx, dim=dom.dim + domain=domain, + indices=index, + index=idx, + dim=new_dom.dim, ) from ex else: # not in new domain assert common.is_int_index(idx) - assert common.is_named_range(dom) and common.is_finite_named_range(dom.urange) - new_index = (dom.urange.start if idx >= 0 else dom.urange.stop) + idx # type: ignore[attr-defined] # urange attr checked in assert above - if new_index < dom.urange.start or new_index >= dom.urange.stop: # type: ignore[attr-defined] # urange attr checked in assert above + assert common.is_finite_named_range(new_dom.urange) + new_index = (new_dom.urange.start if idx >= 0 else new_dom.urange.stop) + idx # type: ignore[attr-defined] # urange attr checked in assert above + if new_index < new_dom.urange.start or new_index >= new_dom.urange.stop: # type: ignore[attr-defined] # urange attr checked in assert above raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=idx, dim=dom.dim + domain=domain, indices=index, index=idx, dim=new_dom.dim ) return common.Domain(*named_ranges) From d942b0580bcce4d68fa6f170cc8861f37c2628b9 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Fri, 15 Mar 2024 10:06:17 +0100 Subject: [PATCH 12/23] edits --- src/gt4py/next/common.py | 10 +++++----- src/gt4py/next/embedded/common.py | 23 ++++++++++++----------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index fa1edc2efb..2aa5a30396 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -436,17 +436,17 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: return all(UnitRange.is_finite(rng) for rng in obj.ranges) @overload - def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... + def __getitem__(self, index: int) -> NamedRange: ... @overload def __getitem__(self, index: slice) -> Self: ... @overload - def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ... + def __getitem__(self, index: Dimension) -> NamedRange: ... - def __getitem__(self, index: int | slice | Dimension) -> tuple[Dimension, _Rng] | Domain: + def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: if isinstance(index, int): - return (self.dims[index], self.ranges[index]) + return NamedRange(dim=self.dims[index], urange=self.ranges[index]) elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] @@ -454,7 +454,7 @@ def __getitem__(self, index: int | slice | Dimension) -> tuple[Dimension, _Rng] elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) - return (self.dims[index_pos], self.ranges[index_pos]) + return NamedRange(dim=self.dims[index_pos], urange=self.ranges[index_pos]) except ValueError as ex: raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex else: diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index bf2345e5fb..9918fa5a95 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -47,26 +47,26 @@ def _relative_sub_domain( ) expanded += (slice(None),) * (len(domain) - len(expanded)) for dom, idx in zip(domain, expanded, strict=True): - new_dom = common.named_range(dom) + assert isinstance(dom, common.NamedRange) if isinstance(idx, slice): try: - sliced = _slice_range(new_dom.urange, idx) - named_ranges.append(common.named_range((new_dom.dim, sliced))) + sliced = _slice_range(dom.urange, idx) + named_ranges.append(common.named_range((dom.dim, sliced))) except IndexError as ex: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=idx, - dim=new_dom.dim, + dim=dom.dim, ) from ex else: # not in new domain assert common.is_int_index(idx) - assert common.is_finite_named_range(new_dom.urange) - new_index = (new_dom.urange.start if idx >= 0 else new_dom.urange.stop) + idx # type: ignore[attr-defined] # urange attr checked in assert above - if new_index < new_dom.urange.start or new_index >= new_dom.urange.stop: # type: ignore[attr-defined] # urange attr checked in assert above + assert common.is_finite_named_range(dom.urange) + new_index = (dom.urange.start if idx >= 0 else dom.urange.stop) + idx + if new_index < dom.urange.start or new_index >= dom.urange.stop: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=idx, dim=new_dom.dim + domain=domain, indices=index, index=idx, dim=dom.dim ) return common.Domain(*named_ranges) @@ -153,9 +153,10 @@ def _find_index_of_dim( for i in range(domain_slice.ndim): if domain_slice.dims[i] == dim: return i - for i, (d, _) in enumerate(domain_slice): - if dim == d: - return i + else: + for i, (d, _) in enumerate(domain_slice): + if dim == d: + return i return None From 37cebd16e347bd2242c9319278cf82d34facd4db Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Fri, 15 Mar 2024 12:45:31 +0100 Subject: [PATCH 13/23] edits --- src/gt4py/next/common.py | 26 +++++++++++------------ src/gt4py/next/embedded/common.py | 21 +++++++++++++----- src/gt4py/next/embedded/nd_array_field.py | 25 +++++++++++++--------- src/gt4py/next/iterator/embedded.py | 4 ++-- 4 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d9a8d0e091..a545e1156b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -248,15 +248,6 @@ def __str__(self) -> str: FiniteUnitRange: TypeAlias = UnitRange[int, int] - -RangeLike: TypeAlias = ( - UnitRange - | range - | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] - | core_defs.IntegralScalar - | None -) - _Rng = TypeVar( "_Rng", UnitRange[int, int], @@ -265,6 +256,14 @@ def __str__(self) -> str: UnitRange[Infinity, Infinity], ) +RangeLike: TypeAlias = ( + _Rng + | range + | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] + | core_defs.IntegralScalar + | None +) + def unit_range(r: RangeLike) -> UnitRange: if isinstance(r, UnitRange): @@ -413,9 +412,10 @@ def __init__( object.__setattr__(self, "ranges", tuple(ranges)) else: if not all(is_named_range(arg) for arg in args): - raise ValueError( - f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." - ) + args = tuple(named_range(arg) for arg in args) + # raise ValueError( + # f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." + # ) dims_new = (arg.dim for arg in args) if args else () ranges_new = (arg.urange for arg in args) if args else () object.__setattr__(self, "dims", tuple(dims_new)) @@ -568,7 +568,7 @@ def domain(domain_like: DomainLike) -> Domain: return Domain( dims=tuple(domain_like.keys()), ranges=tuple( - UnitRange(0, s) # type: ignore[arg-type] # type of `s` is checked in condition + UnitRange(0, s) for s in domain_like.values() ), ) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 492f302e6a..21fe04a8ae 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -143,13 +143,24 @@ def restrict_to_intersection( """ ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) intersection_without_ignore_dims = domain_intersection(*[ - common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple]) + common.Domain(*[ + common.named_range((domain.dims[i], domain.ranges[i])) + for i in range(domain.ndim) + if domain.dims[i] not in ignore_dims_tuple + ]) for domain in domains ]) return tuple( common.Domain(*[ - (d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1]) - for d, r in domain + common.named_range(( + domain.dims[i], + domain.ranges[i] + if domain.dims[i] in ignore_dims_tuple + else intersection_without_ignore_dims.ranges[ + intersection_without_ignore_dims.dims.index(domain.dims[i]) + ], + )) + for i in range(domain.ndim) ]) for domain in domains ) @@ -197,8 +208,8 @@ def _find_index_of_dim( if domain_slice.dims[i] == dim: return i else: - for i, (d, _) in enumerate(domain_slice): - if dim == d: + for i, d_slice in enumerate(domain_slice): + if common.is_named_range(d_slice) and dim == d_slice.dim: return i return None diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 3ef5a04764..0f9e7c5572 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -135,10 +135,10 @@ def asnumpy(self) -> np.ndarray: return np.asarray(self._ndarray) def as_scalar(self) -> core_defs.ScalarT: - if self.domain.ndim != 0: - raise ValueError( - "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." - ) + # if self.domain.ndim != 0: + # raise ValueError( + # "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + # ) return self.ndarray.item() @property @@ -229,6 +229,10 @@ def remap( __call__ = remap # type: ignore[assignment] def restrict(self, index: common.AnyIndexSpec) -> common.Field: + if isinstance(index, tuple): + index = tuple( + common.named_range(ind) if isinstance(ind, tuple) else ind for ind in index + ) # type: ignore[union-attr, assignment] new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] new_buffer = self.__class__.array_ns.asarray(new_buffer) @@ -603,14 +607,15 @@ def _intersect_fields( def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: if not domains: return common.Domain() - dim_start = domains[0][dim][1].start + dim_start = domains[0][domains[0].dims.index(dim)].urange.start dim_stop = dim_start for domain in domains: - if not domain[dim][1].start == dim_stop: + dim_idx = domain.dims.index(dim) + if not domain.ranges[dim_idx].start == dim_stop: return None else: - dim_stop = domain[dim][1].stop - return domains[0].replace(dim, (dim, common.UnitRange(dim_start, dim_stop))) + dim_stop = domain.ranges[dim_idx].stop + return domains[0].replace(dim, common.named_range((dim, common.UnitRange(dim_start, dim_stop)))) def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: @@ -688,7 +693,7 @@ def _concat_where( if transformed: return _concat(*transformed, dim=mask_dim) else: - result_domain = common.Domain((mask_dim, common.UnitRange(0, 0))) + result_domain = common.Domain(common.named_range((mask_dim, common.UnitRange(0, 0)))) result_array = xp.empty(result_domain.shape) return cls_.from_array(result_array, domain=result_domain) @@ -877,7 +882,7 @@ def _get_slices_from_domain_slice( pos := embedded_common._find_index_of_dim(domain.dims[pos_old], domain_slice) ) is not None: index_or_range = ( - domain_slice[pos][1] + domain_slice[pos].urange if isinstance(domain_slice, tuple) else domain_slice.ranges[pos] ) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f73c1a8878..d5b922a36c 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -914,10 +914,10 @@ def _translate_named_indices( assert common.is_int_index( v[0] ) # derefing a concrete element in a sparse field, not a slice - domain_slice.append((d, v[0])) + domain_slice.append(common.named_range((d, v[0]))) else: assert common.is_int_index(v) - domain_slice.append((d, v)) + domain_slice.append(common.named_range((d, common.UnitRange(v, v + 1)))) return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: From 7da5a623a45662334b33a6f18869b7bcecbc15ba Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Fri, 15 Mar 2024 13:06:12 +0100 Subject: [PATCH 14/23] placed error messages back --- src/gt4py/next/common.py | 6 +++--- src/gt4py/next/embedded/nd_array_field.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index a545e1156b..c0690bd5c4 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -413,9 +413,9 @@ def __init__( else: if not all(is_named_range(arg) for arg in args): args = tuple(named_range(arg) for arg in args) - # raise ValueError( - # f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." - # ) + raise ValueError( + f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." + ) dims_new = (arg.dim for arg in args) if args else () ranges_new = (arg.urange for arg in args) if args else () object.__setattr__(self, "dims", tuple(dims_new)) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 0f9e7c5572..a75f0ff54e 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -135,10 +135,10 @@ def asnumpy(self) -> np.ndarray: return np.asarray(self._ndarray) def as_scalar(self) -> core_defs.ScalarT: - # if self.domain.ndim != 0: - # raise ValueError( - # "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." - # ) + if self.domain.ndim != 0: + raise ValueError( + "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + ) return self.ndarray.item() @property From 10903d0e6a836fc7818b0b1b7f8faf21430a189f Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Fri, 15 Mar 2024 14:08:35 +0100 Subject: [PATCH 15/23] some edits --- src/gt4py/next/common.py | 16 +++++++--------- src/gt4py/next/embedded/nd_array_field.py | 2 +- src/gt4py/next/embedded/operators.py | 4 +++- src/gt4py/next/iterator/embedded.py | 4 +++- .../embedded_tests/test_nd_array_field.py | 10 ++++++---- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index c0690bd5c4..905cbe60f4 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -412,10 +412,11 @@ def __init__( object.__setattr__(self, "ranges", tuple(ranges)) else: if not all(is_named_range(arg) for arg in args): + # TODO: put error back args = tuple(named_range(arg) for arg in args) - raise ValueError( - f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." - ) + # raise ValueError( + # f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." + # ) dims_new = (arg.dim for arg in args) if args else () ranges_new = (arg.urange for arg in args) if args else () object.__setattr__(self, "dims", tuple(dims_new)) @@ -491,10 +492,10 @@ def __and__(self, other: Domain) -> Domain: _broadcast_ranges(broadcast_dims, other.dims, other.ranges), ) ) - return Domain(dims=broadcast_dims, ranges=intersected_ranges) + return Domain(dims=broadcast_dims, ranges=intersected_ranges) # TODO def __str__(self) -> str: - return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + return f"Domain({', '.join(f'{e.dim}={e.urange}' for e in self)})" def dim_index(self, dim: Dimension) -> Optional[int]: return self.dims.index(dim) if dim in self.dims else None @@ -567,10 +568,7 @@ def domain(domain_like: DomainLike) -> Domain: if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()): return Domain( dims=tuple(domain_like.keys()), - ranges=tuple( - UnitRange(0, s) - for s in domain_like.values() - ), + ranges=tuple(UnitRange(0, s) for s in domain_like.values()), ) return Domain( dims=tuple(domain_like.keys()), diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index a75f0ff54e..d96d8650c4 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -248,7 +248,7 @@ def __setitem__( target_domain, target_slice = self._slice(index) if common.is_field(value): - if not value.domain == target_domain: + if value.domain != target_domain: raise ValueError( f"Incompatible 'Domain' in assignment. Source domain = '{value.domain}', target domain = '{target_domain}'." ) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 0156eb7096..72dc569081 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -128,7 +128,9 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) with embedded_context.new_context(**new_context_kwargs) as ctx: - res = ctx.run(op, *args, **kwargs) + res = ctx.run( + op, *args, **kwargs + ) # TODO res output with wrong domain for test_unstructured_shift _tuple_assign_field( out, res, diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index d5b922a36c..a0f8da830e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -921,7 +921,9 @@ def _translate_named_indices( return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: - return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar() + # TODO: change this back + return self._ndarrayfield[self._translate_named_indices(named_indices)].ndarray.item() + # return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar() def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if common.is_mutable_field(self._ndarrayfield): diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 7c932533a6..cef1a7c016 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -21,7 +21,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange +from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -430,7 +430,7 @@ def test_field_broadcast(new_dims, field, expected_domain): @pytest.mark.parametrize( "domain_slice", [ - ((D0, UnitRange(0, 10)),), + (NamedRange(D0, UnitRange(0, 10)),), common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), ], ) @@ -446,7 +446,7 @@ def test_get_slices_with_named_index(): field_domain = common.Domain( dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) - named_index = ((D0, UnitRange(0, 10)), (D1, 2), (D2, 3)) + named_index = (NamedRange(D0, UnitRange(0, 10)), (D1, 2), (D2, 3)) slices = _get_slices_from_domain_slice(field_domain, named_index) assert slices == (slice(0, 10, None), 2, 3) @@ -513,7 +513,9 @@ def test_absolute_indexing_dim_sliced(): ) field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field_1 = field[D1(8) : D1(10), D0(5) : D0(9)] - expected = field[(D0, UnitRange(5, 9)), (D1, UnitRange(8, 10))] + expected = field[ + NamedRange(dim=D0, urange=UnitRange(5, 9)), NamedRange(dim=D1, urange=UnitRange(8, 10)) + ] assert common.is_field(indexed_field_1) assert indexed_field_1 == expected From 8d2e8ab3719f7ab2042c611a7f7718404c565bd5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 18 Mar 2024 18:01:57 +0000 Subject: [PATCH 16/23] continue.. --- src/gt4py/next/common.py | 35 +++---- src/gt4py/next/embedded/common.py | 74 +++++---------- .../embedded_tests/test_nd_array_field.py | 18 ++-- tests/next_tests/unit_tests/test_common.py | 94 ++++++++++--------- 4 files changed, 106 insertions(+), 115 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 905cbe60f4..c3dba527ad 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -294,6 +294,9 @@ class NamedRange(Generic[_Rng]): dim: Dimension urange: _Rng + def __iter__(self): + return iter((self.dim, self.urange)) + def __str__(self) -> str: return f"{self.dim}={self.urange}" @@ -319,12 +322,12 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]: return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) -def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: - return ( - isinstance(v, NamedRange) - and isinstance(v.dim, Dimension) - and isinstance(v.urange, UnitRange) - ) +# def isinstance(v: AnyIndexSpec, common.NamedRange) -> TypeGuard[NamedRange]: +# return ( +# isinstance(v, NamedRange) +# and isinstance(v.dim, Dimension) +# and isinstance(v.urange, UnitRange) +# ) def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: @@ -344,7 +347,7 @@ def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]: def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: return ( is_int_index(v) - or is_named_range(v) + or isinstance(v, NamedRange) or is_named_index(v) or isinstance(v, slice) or v is Ellipsis @@ -352,7 +355,9 @@ def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: - return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) + return isinstance(v, Sequence) and all( + isinstance(e, NamedRange) or is_named_index(e) for e in v + ) def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: @@ -374,7 +379,7 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: @dataclasses.dataclass(frozen=True, init=False) -class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): +class Domain(Sequence[NamedRange[_Rng]], Generic[_Rng]): """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" dims: tuple[Dimension, ...] @@ -411,12 +416,10 @@ def __init__( object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) else: - if not all(is_named_range(arg) for arg in args): - # TODO: put error back - args = tuple(named_range(arg) for arg in args) - # raise ValueError( - # f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." - # ) + if not all(isinstance(arg, NamedRange) for arg in args): + raise ValueError( + f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." + ) dims_new = (arg.dim for arg in args) if args else () ranges_new = (arg.urange for arg in args) if args else () object.__setattr__(self, "dims", tuple(dims_new)) @@ -511,7 +514,7 @@ def insert(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return self.replace(index, *named_ranges) def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: - assert all(is_named_range(nr) for nr in named_ranges) + assert all(isinstance(nr, NamedRange) for nr in named_ranges) if isinstance(index, Dimension): dim_index = self.dim_index(index) if dim_index is None: diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 21fe04a8ae..d4dddcec45 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -46,27 +46,24 @@ def _relative_sub_domain( f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions." ) expanded += (slice(None),) * (len(domain) - len(expanded)) - for dom, idx in zip(domain, expanded, strict=True): - assert isinstance(dom, common.NamedRange) + # TODO undo these changes + for nr, idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): try: - sliced = _slice_range(dom.urange, idx) - named_ranges.append(common.named_range((dom.dim, sliced))) + sliced = _slice_range(nr.urange, idx) + named_ranges.append(common.NamedRange(nr.dim, sliced)) except IndexError as ex: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, - indices=index, - index=idx, - dim=dom.dim, + domain=domain, indices=index, index=idx, dim=dim ) from ex else: # not in new domain assert common.is_int_index(idx) - assert common.is_finite_named_range(dom.urange) - new_index = (dom.urange.start if idx >= 0 else dom.urange.stop) + idx - if new_index < dom.urange.start or new_index >= dom.urange.stop: + assert common.UnitRange.is_finite(nr.urange) + new_index = (nr.urange.start if idx >= 0 else nr.urange.stop) + idx + if new_index < nr.urange.start or new_index >= nr.urange.stop: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=idx, dim=dom.dim + domain=domain, indices=index, index=idx, dim=dim ) return common.Domain(*named_ranges) @@ -76,30 +73,27 @@ def _absolute_sub_domain( domain: common.Domain, index: common.AbsoluteIndexSequence ) -> common.Domain: named_ranges: list[common.NamedRange] = [] - for i in range(domain.ndim): - if (pos := _find_index_of_dim(domain.dims[i], index)) is not None: + for i, nr in enumerate(domain): + if (pos := _find_index_of_dim(nr.dim, index)) is not None: named_idx = index[pos] - if not isinstance(named_idx, common.NamedRange): - idx = named_idx[1] - else: - idx = named_idx.urange + idx = named_idx[1] if isinstance(idx, common.UnitRange): - if not idx <= domain.ranges[i]: + if not idx <= nr.urange: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=named_idx, dim=domain.dims[i] + domain=domain, indices=index, index=named_idx, dim=dim ) - named_ranges.append(common.named_range((domain.dims[i], idx))) + named_ranges.append((nr.dim, idx)) else: # not in new domain assert common.is_int_index(idx) - if idx < domain.ranges[i].start or idx >= domain.ranges[i].stop: + if idx < nr.urange.start or idx >= rng.stop: raise embedded_exceptions.IndexOutOfBounds( - domain=domain, indices=index, index=named_idx, dim=domain.dims[i] + domain=domain, indices=index, index=named_idx, dim=dim ) else: # dimension not mentioned in slice - named_ranges.append(common.named_range((domain.dims[i], domain.ranges[i]))) + named_ranges.append(common.NamedRange(dim, domain.ranges[i])) return common.Domain(*named_ranges) @@ -143,24 +137,13 @@ def restrict_to_intersection( """ ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) intersection_without_ignore_dims = domain_intersection(*[ - common.Domain(*[ - common.named_range((domain.dims[i], domain.ranges[i])) - for i in range(domain.ndim) - if domain.dims[i] not in ignore_dims_tuple - ]) + common.Domain(*[nr for nr in domain if nr.dim not in ignore_dims_tuple]) for domain in domains ]) return tuple( common.Domain(*[ - common.named_range(( - domain.dims[i], - domain.ranges[i] - if domain.dims[i] in ignore_dims_tuple - else intersection_without_ignore_dims.ranges[ - intersection_without_ignore_dims.dims.index(domain.dims[i]) - ], - )) - for i in range(domain.ndim) + (nr if nr.dim in ignore_dims_tuple else intersection_without_ignore_dims[nr.dim].urange) + for nr in domain ]) for domain in domains ) @@ -201,16 +184,11 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit def _find_index_of_dim( dim: common.Dimension, - domain_slice: common.Domain | Sequence[common.NamedIndex | Any], + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], ) -> Optional[int]: - if isinstance(domain_slice, common.Domain): - for i in range(domain_slice.ndim): - if domain_slice.dims[i] == dim: - return i - else: - for i, d_slice in enumerate(domain_slice): - if common.is_named_range(d_slice) and dim == d_slice.dim: - return i + for i, (d, _) in enumerate(domain_slice): + if dim == d: + return i return None @@ -240,7 +218,7 @@ def _named_slice_to_named_range( f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'." ) assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int) - return common.named_range((idx_start_0, common.UnitRange(idx_start_1, idx_stop_1))) + return common.NamedRange(idx_start_0, common.UnitRange(idx_start_1, idx_stop_1)) if common.is_named_index(idx.start) and idx.stop is None: raise IndexError(f"Upper bound needs to be specified for {idx}.") if common.is_named_index(idx.stop) and idx.start is None: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index cef1a7c016..67eb203371 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -570,23 +570,27 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 4))), + Domain(NamedRange(D0, UnitRange(5, 10)), NamedRange(D1, UnitRange(2, 4))), + ), + ( + (slice(None, 5),), + (5, 10), + Domain(NamedRange(D0, UnitRange(5, 10)), NamedRange(D1, UnitRange(2, 12))), ), - ((slice(None, 5),), (5, 10), Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 12)))), ( (Ellipsis, 1), (10,), - Domain((D0, UnitRange(5, 15))), + Domain(NamedRange(D0, UnitRange(5, 15))), ), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((D0, UnitRange(7, 8)), (D1, UnitRange(7, 9))), + Domain(NamedRange(D0, UnitRange(7, 8)), NamedRange(D1, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((D0, UnitRange(6, 7))), + Domain(NamedRange(D0, UnitRange(6, 7))), ), ], ) @@ -695,7 +699,9 @@ def test_field_unsupported_index(index): ((1, slice(None)), np.ones((10,)) * 42.0), ( (1, slice(None)), - common._field(np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(0, 10)))), + common._field( + np.ones((10,)) * 42.0, domain=common.Domain(NamedRange(D1, UnitRange(0, 10))) + ), ), ], ) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index ce940131c3..4173e159d4 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -24,6 +24,7 @@ UnitRange, domain, named_range, + NamedRange, promote_dims, unit_range, ) @@ -261,10 +262,10 @@ def test_domain_length(a_domain): "empty_domain, expected", [ (Domain(), False), - (Domain((IDim, UnitRange(0, 10))), False), - (Domain((IDim, UnitRange(0, 0))), True), - (Domain((IDim, UnitRange(0, 0)), (JDim, UnitRange(0, 1))), True), - (Domain((IDim, UnitRange(0, 1)), (JDim, UnitRange(0, 0))), True), + (Domain(NamedRange(IDim, UnitRange(0, 10))), False), + (Domain(NamedRange(IDim, UnitRange(0, 0))), True), + (Domain(NamedRange(IDim, UnitRange(0, 0)), NamedRange(JDim, UnitRange(0, 1))), True), + (Domain(NamedRange(IDim, UnitRange(0, 1)), NamedRange(JDim, UnitRange(0, 0))), True), ], ) def test_empty_domain(empty_domain, expected): @@ -436,89 +437,92 @@ def test_domain_pop(): # Valid index and named ranges ( 0, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("X"), UnitRange(100, 110)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ), ( 1, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("X"), UnitRange(100, 110)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ), ( -1, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), ), ), ( Dimension("J"), - [(Dimension("X"), UnitRange(100, 110)), (Dimension("Z"), UnitRange(100, 110))], + [ + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("Z"), UnitRange(100, 110)), + ], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("X"), UnitRange(100, 110)), - (Dimension("Z"), UnitRange(100, 110)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("Z"), UnitRange(100, 110)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ), # Invalid indices ( 3, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), IndexError, ), ( -4, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), IndexError, ), ( Dimension("Foo"), - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ValueError, ), From 47a1fd487e314f349f5a752d4bebec1fd9a5d379 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 19 Mar 2024 13:49:35 +0000 Subject: [PATCH 17/23] finish first round of refactoring --- src/gt4py/next/common.py | 66 +++++++++---------- src/gt4py/next/embedded/common.py | 47 ++++++------- src/gt4py/next/embedded/nd_array_field.py | 48 +++++--------- src/gt4py/next/embedded/operators.py | 2 +- src/gt4py/next/iterator/embedded.py | 21 +++--- .../unit_tests/embedded_tests/test_common.py | 38 +++++------ .../embedded_tests/test_nd_array_field.py | 35 +++++----- .../iterator_tests/test_embedded_internals.py | 4 +- tests/next_tests/unit_tests/test_common.py | 6 +- 9 files changed, 125 insertions(+), 142 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index c3dba527ad..1c56b36602 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -34,6 +34,7 @@ ClassVar, Final, Generic, + NamedTuple, Never, Optional, ParamSpec, @@ -84,7 +85,7 @@ def __str__(self): return f"{self.value}[{self.kind}]" def __call__(self, val: int) -> NamedIndex: - return self, val + return NamedIndex(self, val) class Infinity(enum.Enum): @@ -289,21 +290,26 @@ def unit_range(r: RangeLike) -> UnitRange: raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") -@dataclasses.dataclass(frozen=True) -class NamedRange(Generic[_Rng]): +class NamedRange(NamedTuple, Generic[_Rng]): dim: Dimension - urange: _Rng - - def __iter__(self): - return iter((self.dim, self.urange)) + unit_range: _Rng def __str__(self) -> str: - return f"{self.dim}={self.urange}" + return f"{self.dim}={self.unit_range}" IntIndex: TypeAlias = int | core_defs.IntegralScalar -NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple -FiniteNamedRange: TypeAlias = NamedRange[FiniteUnitRange] # TODO: convert to NamedTuple + + +class NamedIndex(NamedTuple): + dim: Dimension + index: IntIndex # type: ignore[assignment] # overriding tuple.index + + def __str__(self) -> str: + return f"{self.dim}={self.index}" + + +FiniteNamedRange: TypeAlias = NamedRange[FiniteUnitRange] RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice @@ -322,33 +328,21 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]: return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) -# def isinstance(v: AnyIndexSpec, common.NamedRange) -> TypeGuard[NamedRange]: -# return ( -# isinstance(v, NamedRange) -# and isinstance(v.dim, Dimension) -# and isinstance(v.urange, UnitRange) -# ) - - def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: - return UnitRange.is_finite(v.urange) + return UnitRange.is_finite(v.unit_range) -def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedIndex]: - return ( - isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) +def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[slice]: + return isinstance(obj, slice) and ( + isinstance(obj.start, NamedIndex) and isinstance(obj.stop, NamedIndex) ) -def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]: - return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop)) - - def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: return ( is_int_index(v) or isinstance(v, NamedRange) - or is_named_index(v) + or isinstance(v, NamedIndex) or isinstance(v, slice) or v is Ellipsis ) @@ -356,7 +350,7 @@ def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: return isinstance(v, Sequence) and all( - isinstance(e, NamedRange) or is_named_index(e) for e in v + isinstance(e, NamedRange) or isinstance(e, NamedIndex) for e in v ) @@ -375,7 +369,9 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: - return NamedRange(dim=v[0], urange=unit_range(v[1])) + if isinstance(v, NamedRange): + return v + return NamedRange(v[0], unit_range(v[1])) @dataclasses.dataclass(frozen=True, init=False) @@ -421,7 +417,7 @@ def __init__( f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." ) dims_new = (arg.dim for arg in args) if args else () - ranges_new = (arg.urange for arg in args) if args else () + ranges_new = (arg.unit_range for arg in args) if args else () object.__setattr__(self, "dims", tuple(dims_new)) object.__setattr__(self, "ranges", tuple(ranges_new)) @@ -458,7 +454,7 @@ def __getitem__(self, index: Dimension) -> NamedRange: ... def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: if isinstance(index, int): - return NamedRange(dim=self.dims[index], urange=self.ranges[index]) + return NamedRange(dim=self.dims[index], unit_range=self.ranges[index]) elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] @@ -466,7 +462,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) - return NamedRange(dim=self.dims[index_pos], urange=self.ranges[index_pos]) + return NamedRange(dim=self.dims[index_pos], unit_range=self.ranges[index_pos]) except ValueError as ex: raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex else: @@ -498,7 +494,7 @@ def __and__(self, other: Domain) -> Domain: return Domain(dims=broadcast_dims, ranges=intersected_ranges) # TODO def __str__(self) -> str: - return f"Domain({', '.join(f'{e.dim}={e.urange}' for e in self)})" + return f"Domain({', '.join(f'{e}' for e in self)})" def dim_index(self, dim: Dimension) -> Optional[int]: return self.dims.index(dim) if dim in self.dims else None @@ -527,7 +523,7 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: if index < 0: index += len(self.dims) new_dims = (arg.dim for arg in named_ranges) if len(named_ranges) > 0 else () - new_ranges = (arg.urange for arg in named_ranges) if len(named_ranges) > 0 else () + new_ranges = (arg.unit_range for arg in named_ranges) if len(named_ranges) > 0 else () dims = self.dims[:index] + tuple(new_dims) + self.dims[index + 1 :] ranges = self.ranges[:index] + tuple(new_ranges) + self.ranges[index + 1 :] @@ -979,7 +975,7 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." ) - image_range = image_range.urange + image_range = image_range.unit_range assert isinstance(image_range, UnitRange) return (named_range((self.codomain, image_range - self.offset)),) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index d4dddcec45..e2c330b541 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -46,12 +46,11 @@ def _relative_sub_domain( f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions." ) expanded += (slice(None),) * (len(domain) - len(expanded)) - # TODO undo these changes - for nr, idx in zip(domain, expanded, strict=True): + for (dim, rng), idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): try: - sliced = _slice_range(nr.urange, idx) - named_ranges.append(common.NamedRange(nr.dim, sliced)) + sliced = _slice_range(rng, idx) + named_ranges.append(common.NamedRange(dim, sliced)) except IndexError as ex: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=idx, dim=dim @@ -59,9 +58,9 @@ def _relative_sub_domain( else: # not in new domain assert common.is_int_index(idx) - assert common.UnitRange.is_finite(nr.urange) - new_index = (nr.urange.start if idx >= 0 else nr.urange.stop) + idx - if new_index < nr.urange.start or new_index >= nr.urange.stop: + assert common.UnitRange.is_finite(rng) + new_index = (rng.start if idx >= 0 else rng.stop) + idx + if new_index < rng.start or new_index >= rng.stop: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=idx, dim=dim ) @@ -73,21 +72,21 @@ def _absolute_sub_domain( domain: common.Domain, index: common.AbsoluteIndexSequence ) -> common.Domain: named_ranges: list[common.NamedRange] = [] - for i, nr in enumerate(domain): - if (pos := _find_index_of_dim(nr.dim, index)) is not None: + for i, (dim, rng) in enumerate(domain): + if (pos := _find_index_of_dim(dim, index)) is not None: named_idx = index[pos] - idx = named_idx[1] + _, idx = named_idx if isinstance(idx, common.UnitRange): - if not idx <= nr.urange: + if not idx <= rng: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=named_idx, dim=dim ) - named_ranges.append((nr.dim, idx)) + named_ranges.append(common.NamedRange(dim, idx)) else: # not in new domain assert common.is_int_index(idx) - if idx < nr.urange.start or idx >= rng.stop: + if idx < rng.start or idx >= rng.stop: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=named_idx, dim=dim ) @@ -142,7 +141,7 @@ def restrict_to_intersection( ]) return tuple( common.Domain(*[ - (nr if nr.dim in ignore_dims_tuple else intersection_without_ignore_dims[nr.dim].urange) + (nr if nr.dim in ignore_dims_tuple else intersection_without_ignore_dims[nr.dim]) for nr in domain ]) for domain in domains @@ -207,20 +206,16 @@ def _named_slice_to_named_range( ) -> common.NamedRange | common.NamedSlice: assert hasattr(idx, "start") and hasattr(idx, "stop") if common.is_named_slice(idx): - idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = ( - idx.start[0], # type: ignore[attr-defined] - idx.start[1], # type: ignore[attr-defined] - idx.stop[0], # type: ignore[attr-defined] - idx.stop[1], # type: ignore[attr-defined] - ) - if idx_start_0 != idx_stop_0: + start_dim, start_value = idx.start + stop_dim, stop_value = idx.stop + if start_dim != stop_dim: raise IndexError( - f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'." + f"Dimensions slicing mismatch between '{start_dim.value}' and '{stop_dim.value}'." ) - assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int) - return common.NamedRange(idx_start_0, common.UnitRange(idx_start_1, idx_stop_1)) - if common.is_named_index(idx.start) and idx.stop is None: + assert isinstance(start_value, int) and isinstance(stop_value, int) + return common.NamedRange(start_dim, common.UnitRange(start_value, stop_value)) + if isinstance(idx.start, common.NamedIndex) and idx.stop is None: raise IndexError(f"Upper bound needs to be specified for {idx}.") - if common.is_named_index(idx.stop) and idx.start is None: + if isinstance(idx.stop, common.NamedIndex) and idx.start is None: raise IndexError(f"Lower bound needs to be specified for {idx}.") return idx diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index d96d8650c4..1be70ae379 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -24,7 +24,7 @@ from numpy import typing as npt from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar +from gt4py.eve.extended_typing import Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common from gt4py.next.embedded import ( common as embedded_common, @@ -196,7 +196,7 @@ def remap( if dim_idx is None: raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") - current_range: common.UnitRange = self.domain.ranges[dim_idx] + current_range: common.UnitRange = self.domain[dim_idx].unit_range new_ranges = connectivity.inverse_image(current_range) new_domain = self.domain.replace(dim_idx, *new_ranges) @@ -229,10 +229,6 @@ def remap( __call__ = remap # type: ignore[assignment] def restrict(self, index: common.AnyIndexSpec) -> common.Field: - if isinstance(index, tuple): - index = tuple( - common.named_range(ind) if isinstance(ind, tuple) else ind for ind in index - ) # type: ignore[union-attr, assignment] new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] new_buffer = self.__class__.array_ns.asarray(new_buffer) @@ -248,7 +244,7 @@ def __setitem__( target_domain, target_slice = self._slice(index) if common.is_field(value): - if value.domain != target_domain: + if not value.domain == target_domain: raise ValueError( f"Incompatible 'Domain' in assignment. Source domain = '{value.domain}', target domain = '{target_domain}'." ) @@ -422,7 +418,7 @@ def inverse_image( f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." ) - image_range = image_range.urange + image_range = image_range.unit_range assert isinstance(image_range, common.UnitRange) @@ -607,15 +603,14 @@ def _intersect_fields( def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: if not domains: return common.Domain() - dim_start = domains[0][domains[0].dims.index(dim)].urange.start + dim_start = domains[0][dim].unit_range.start dim_stop = dim_start for domain in domains: - dim_idx = domain.dims.index(dim) - if not domain.ranges[dim_idx].start == dim_stop: + if not domain[dim].unit_range.start == dim_stop: return None else: - dim_stop = domain.ranges[dim_idx].stop - return domains[0].replace(dim, common.named_range((dim, common.UnitRange(dim_start, dim_stop)))) + dim_stop = domain[dim].unit_range.stop + return domains[0].replace(dim, common.NamedRange(dim, common.UnitRange(dim_start, dim_stop))) def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: @@ -693,7 +688,7 @@ def _concat_where( if transformed: return _concat(*transformed, dim=mask_dim) else: - result_domain = common.Domain(common.named_range((mask_dim, common.UnitRange(0, 0)))) + result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0))) result_array = xp.empty(result_domain.shape) return cls_.from_array(result_array, domain=result_domain) @@ -727,11 +722,7 @@ def _builtin_op( axis.value ] # assumes offset and local dimension have same name assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) - new_domain = common.Domain(*[ - common.named_range((field.domain.dims[idx], field.domain.ranges[idx])) - for idx, dim in enumerate(field.domain.dims) - if dim != axis - ]) + new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis @@ -829,11 +820,10 @@ def _broadcast(field: common.Field, new_dimensions: Sequence[common.Dimension]) for dim in new_dimensions: if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) - named_ranges.append(common.named_range((dim, field.domain.ranges[pos]))) + named_ranges.append(common.NamedRange(dim, field.domain[pos].unit_range)) else: domain_slice.append(None) # np.newaxis - named_ranges.append(common.named_range((dim, common.UnitRange.infinite()))) - + named_ranges.append(common.NamedRange(dim, common.UnitRange.infinite())) return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -859,7 +849,7 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA def _get_slices_from_domain_slice( domain: common.Domain, - domain_slice: common.Domain | Sequence[Any], + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex], ) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. @@ -877,15 +867,9 @@ def _get_slices_from_domain_slice( """ slice_indices: list[slice | common.IntIndex] = [] - for pos_old in range(domain.ndim): - if ( - pos := embedded_common._find_index_of_dim(domain.dims[pos_old], domain_slice) - ) is not None: - index_or_range = ( - domain_slice[pos].urange - if isinstance(domain_slice, tuple) - else domain_slice.ranges[pos] - ) + for pos_old, (dim, _) in enumerate(domain): + if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: + _, index_or_range = domain_slice[pos] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: slice_indices.append(slice(None)) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 72dc569081..869b050fb2 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -79,7 +79,7 @@ def __call__( # type: ignore[override] def scan_loop(hpos): acc = self.init - for k in scan_range.urange if self.forward else reversed(scan_range.urange): + for k in scan_range.unit_range if self.forward else reversed(scan_range.unit_range): pos = (*hpos, (scan_axis, k)) new_args = [_tuple_at(pos, arg) for arg in args] new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a0f8da830e..86b7c41a14 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -206,7 +206,7 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 column_range: common.NamedRange = column_range_cvar.get() self.data = ( - data if isinstance(data, np.ndarray) else np.full(len(column_range.urange), data) + data if isinstance(data, np.ndarray) else np.full(len(column_range.unit_range), data) ) def __getitem__(self, i: int) -> Any: @@ -748,7 +748,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get().urange + column_range = column_range_cvar.get().unit_range assert column_range is not None col: list[ @@ -825,7 +825,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.urange)) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.unit_range)) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -866,7 +866,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get().urange + column_range = column_range_cvar.get().unit_range # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -921,9 +921,7 @@ def _translate_named_indices( return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: - # TODO: change this back - return self._ndarrayfield[self._translate_named_indices(named_indices)].ndarray.item() - # return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar() + return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar() def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if common.is_mutable_field(self._ndarrayfield): @@ -1093,8 +1091,11 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - raise NotImplementedError() def restrict(self, item: common.AnyIndexSpec) -> common.Field: - if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off - assert common.is_named_index(item[0]) # for mypy errors on multiple lines below + if ( + common.is_absolute_index_sequence(item) # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off + and all(isinstance(e, common.NamedIndex) for e in item) + ): + assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below d, r = item[0] assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) @@ -1496,7 +1497,7 @@ def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): columns = column_range_cvar.get() assert isinstance(columns, common.NamedRange) - column_range = columns.urange + column_range = columns.unit_range if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 9765273f94..111622ac42 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -17,7 +17,7 @@ import pytest from gt4py.next import common -from gt4py.next.common import UnitRange +from gt4py.next.common import UnitRange, NamedIndex, NamedRange from gt4py.next.embedded import exceptions as embedded_exceptions from gt4py.next.embedded.common import ( _slice_range, @@ -53,12 +53,12 @@ def test_slice_range(rng, slce, expected): [ ([(I, (2, 5))], 1, []), ([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]), - ([(I, (2, 5))], (I, 2), []), - ([(I, (2, 5))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (2, 5))], NamedIndex(I, 2), []), + ([(I, (2, 5))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3))]), ([(I, (-2, 3))], 1, []), ([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]), - ([(I, (-2, 3))], (I, 1), []), - ([(I, (-2, 3))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], NamedIndex(I, 1), []), + ([(I, (-2, 3))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3))]), ([(I, (-2, 3))], -5, []), ([(I, (-2, 3))], -6, IndexError), ([(I, (-2, 3))], slice(-7, -6), IndexError), @@ -67,10 +67,10 @@ def test_slice_range(rng, slce, expected): ([(I, (-2, 3))], 5, IndexError), ([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]), ([(I, (-2, 3))], slice(5, 6), IndexError), - ([(I, (-2, 3))], (I, -3), IndexError), - ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), - ([(I, (-2, 3))], (I, 3), IndexError), - ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), + ([(I, (-2, 3))], NamedIndex(I, -3), IndexError), + ([(I, (-2, 3))], NamedRange(I, UnitRange(-3, -2)), IndexError), + ([(I, (-2, 3))], NamedIndex(I, 3), IndexError), + ([(I, (-2, 3))], NamedRange(I, UnitRange(3, 4)), IndexError), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], 2, @@ -83,32 +83,32 @@ def test_slice_range(rng, slce, expected): ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, 2), + NamedIndex(I, 2), [(J, (3, 6)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, UnitRange(2, 3)), + NamedRange(I, UnitRange(2, 3)), [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, 3), + NamedIndex(J, 3), [(I, (2, 5)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, UnitRange(4, 5)), + NamedRange(J, UnitRange(4, 5)), [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, 3), (I, 2)), + (NamedIndex(J, 3), NamedIndex(I, 2)), [(K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, UnitRange(4, 5)), (I, 2)), + (NamedRange(J, UnitRange(4, 5)), NamedIndex(I, 2)), [(J, (4, 5)), (K, (4, 7))], ), ( @@ -147,8 +147,8 @@ def test_sub_domain(domain, index, expected): def test_iterate_domain(): domain = common.domain({I: 2, J: 3}) ref = [] - for i in domain[I][1]: - for j in domain[J][1]: + for i in domain[I].unit_range: + for j in domain[J].unit_range: ref.append(((I, i), (J, j))) testee = list(iterate_domain(domain)) @@ -159,10 +159,10 @@ def test_iterate_domain(): @pytest.mark.parametrize( "slices, expected", [ - [slice(I(3), I(4)), ((I, common.UnitRange(3, 4)),)], + [slice(I(3), I(4)), (NamedRange(I, common.UnitRange(3, 4)),)], [ (slice(J(3), J(6)), slice(I(3), I(5))), - ((J, common.UnitRange(3, 6)), (I, common.UnitRange(3, 5))), + (NamedRange(J, common.UnitRange(3, 6)), NamedRange(I, common.UnitRange(3, 5))), ], [slice(I(1), J(7)), IndexError], [ diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 67eb203371..7171bb5ecc 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -21,7 +21,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange +from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange, NamedIndex from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -465,34 +465,34 @@ def test_get_slices_invalid_type(): [ ( ( - (D0, UnitRange(7, 9)), - (D1, UnitRange(8, 10)), + NamedRange(D0, UnitRange(7, 9)), + NamedRange(D1, UnitRange(8, 10)), ), (D0, D1, D2), (2, 2, 15), ), ( ( - (D0, UnitRange(7, 9)), - (D2, UnitRange(12, 20)), + NamedRange(D0, UnitRange(7, 9)), + NamedRange(D2, UnitRange(12, 20)), ), (D0, D1, D2), (2, 10, 8), ), (common.Domain(dims=(D0,), ranges=(UnitRange(7, 9),)), (D0, D1, D2), (2, 10, 15)), - (((D0, 8),), (D1, D2), (10, 15)), - (((D1, 9),), (D0, D2), (5, 15)), - (((D2, 11),), (D0, D1), (5, 10)), + ((NamedIndex(D0, 8),), (D1, D2), (10, 15)), + ((NamedIndex(D1, 9),), (D0, D2), (5, 15)), + ((NamedIndex(D2, 11),), (D0, D1), (5, 10)), ( ( - (D0, 8), - (D1, UnitRange(8, 10)), + NamedIndex(D0, 8), + NamedRange(D1, UnitRange(8, 10)), ), (D1, D2), (2, 15), ), - ((D0, 5), (D1, D2), (10, 15)), - ((D0, UnitRange(5, 7)), (D0, D1, D2), (2, 10, 15)), + (NamedIndex(D0, 5), (D1, D2), (10, 15)), + (NamedRange(D0, UnitRange(5, 7)), (D0, D1, D2), (2, 10, 15)), ], ) def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): @@ -514,7 +514,8 @@ def test_absolute_indexing_dim_sliced(): field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field_1 = field[D1(8) : D1(10), D0(5) : D0(9)] expected = field[ - NamedRange(dim=D0, urange=UnitRange(5, 9)), NamedRange(dim=D1, urange=UnitRange(8, 10)) + NamedRange(dim=D0, unit_range=UnitRange(5, 9)), + NamedRange(dim=D1, unit_range=UnitRange(8, 10)), ] assert common.is_field(indexed_field_1) @@ -527,7 +528,7 @@ def test_absolute_indexing_dim_sliced_single_slice(): ) field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field_1 = field[D2(11)] - indexed_field_2 = field[(D2, 11)] + indexed_field_2 = field[NamedIndex(D2, 11)] assert common.is_field(indexed_field_1) assert indexed_field_1 == indexed_field_2 @@ -556,7 +557,7 @@ def test_absolute_indexing_value_return(): domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(10, 20), UnitRange(5, 15))) field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((D0, 12), (D1, 6)) + named_index = (NamedIndex(D0, 12), NamedIndex(D1, 6)) assert common.is_field(field) value = field[named_index] @@ -726,7 +727,7 @@ def test_setitem_wrong_domain(): ) value_incompatible = common._field( - np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(-5, 5))) + np.ones((10,)) * 42.0, domain=common.Domain(NamedRange(D1, UnitRange(-5, 5))) ) with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): @@ -759,7 +760,7 @@ def test_connectivity_field_inverse_image(): # Test codomain with pytest.raises(ValueError, match="does not match the codomain dimension"): - e2v_conn.inverse_image((E, UnitRange(1, 2))) + e2v_conn.inverse_image(NamedRange(E, UnitRange(1, 2))) def test_connectivity_field_inverse_image_2d_domain(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py index 9238cd4f7a..ec6e613529 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py @@ -62,7 +62,9 @@ def test_func(data_a: int, data_b: int): embedded.column_range_cvar.set(range(2, 999)) _run_within_context( lambda: test_func(2, 3), - column_range=(common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3)), + column_range=common.NamedRange( + common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3) + ), ) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 4173e159d4..1aeb51cb30 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -38,7 +38,11 @@ @pytest.fixture def a_domain(): - return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) + return Domain( + NamedRange(IDim, UnitRange(0, 10)), + NamedRange(JDim, UnitRange(5, 15)), + NamedRange(KDim, UnitRange(20, 30)), + ) @pytest.fixture(params=[Infinity.POSITIVE, Infinity.NEGATIVE]) From b36311a74c161a3db1ee717aaf2cc7d550844038 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 19 Mar 2024 13:54:28 +0000 Subject: [PATCH 18/23] cleanup operators --- src/gt4py/next/embedded/operators.py | 31 +++++----------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 869b050fb2..a6b511c840 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -52,23 +52,10 @@ def __call__( # type: ignore[override] scan_axis = scan_range.dim all_args = [*args, *kwargs.values()] domain_intersection = _intersect_scan_args(*all_args) - non_scan_domain = common.Domain(*[ - common.named_range(( - domain_intersection.dims[idx_nr], - domain_intersection.ranges[idx_nr], - )) - for idx_nr, nr in enumerate(domain_intersection.dims) - if nr != scan_axis - ]) + non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr.dim != scan_axis]) out_domain = common.Domain(*[ - scan_range - if nr == scan_axis - else common.named_range(( - domain_intersection.dims[idx_nr], - domain_intersection.ranges[idx_nr], - )) - for idx_nr, nr in enumerate(domain_intersection.dims) + scan_range if nr.dim == scan_axis else nr for nr in domain_intersection ]) if scan_axis not in out_domain.dims: # even if the scan dimension is not in the input, we can scan over it @@ -128,9 +115,7 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) with embedded_context.new_context(**new_context_kwargs) as ctx: - res = ctx.run( - op, *args, **kwargs - ) # TODO res output with wrong domain for test_unstructured_shift + res = ctx.run(op, *args, **kwargs) _tuple_assign_field( out, res, @@ -144,14 +129,8 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): return op(*args, **kwargs) -def _get_vertical_range( - domain: common.Domain, -) -> common.NamedRange | eve.NothingType: - vertical_dim_filtered = [ - common.named_range((domain.dims[idx_nr], domain.ranges[idx_nr])) - for idx_nr, nr in enumerate(domain.dims) - if nr.kind == common.DimensionKind.VERTICAL - ] +def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: + vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL] assert len(vertical_dim_filtered) <= 1 return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING From 35af9687f8597f17d6d87d4728d712796b4265ca Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 19 Mar 2024 13:56:44 +0000 Subject: [PATCH 19/23] cleanup iterator.embedded --- src/gt4py/next/iterator/embedded.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 86b7c41a14..7e4e942d3f 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -908,16 +908,16 @@ def _translate_named_indices( domain_slice: list[common.NamedRange | common.NamedIndex] = [] for d, v in named_indices.items(): if isinstance(v, range): - domain_slice.append(common.named_range((d, common.UnitRange(v.start, v.stop)))) + domain_slice.append(common.NamedRange(d, common.UnitRange(v.start, v.stop))) elif isinstance(v, list): assert len(v) == 1 # only 1 sparse dimension is supported assert common.is_int_index( v[0] ) # derefing a concrete element in a sparse field, not a slice - domain_slice.append(common.named_range((d, v[0]))) + domain_slice.append(common.NamedRange(d, v[0])) else: assert common.is_int_index(v) - domain_slice.append(common.named_range((d, common.UnitRange(v, v + 1)))) + domain_slice.append(common.NamedRange(d, common.UnitRange(v, v + 1))) return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: @@ -1059,7 +1059,7 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: if self._cur_index is None: - return common.Domain(common.named_range((self._dimension, common.UnitRange.infinite()))) + return common.Domain(common.NamedRange(self._dimension, common.UnitRange.infinite())) else: return common.Domain() @@ -1550,10 +1550,10 @@ def closure( column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] - column_range = common.named_range(( + column_range = common.NamedRange( column_axis, common.UnitRange(column.col_range.start, column.col_range.stop), - )) + ) out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) From 38c1585bd431565356ef5395cba6d8a0f15f4c2e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 19 Mar 2024 14:13:30 +0000 Subject: [PATCH 20/23] fix bug --- src/gt4py/next/embedded/nd_array_field.py | 2 +- src/gt4py/next/iterator/embedded.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1be70ae379..72a8343dc4 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -137,7 +137,7 @@ def asnumpy(self) -> np.ndarray: def as_scalar(self) -> core_defs.ScalarT: if self.domain.ndim != 0: raise ValueError( - "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." ) return self.ndarray.item() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 7e4e942d3f..1e753d24f2 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -914,10 +914,10 @@ def _translate_named_indices( assert common.is_int_index( v[0] ) # derefing a concrete element in a sparse field, not a slice - domain_slice.append(common.NamedRange(d, v[0])) + domain_slice.append(common.NamedIndex(d, v[0])) else: assert common.is_int_index(v) - domain_slice.append(common.NamedRange(d, common.UnitRange(v, v + 1))) + domain_slice.append(common.NamedIndex(d, v)) return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: @@ -1495,9 +1495,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - columns = column_range_cvar.get() - assert isinstance(columns, common.NamedRange) - column_range = columns.unit_range + column_range = column_range_cvar.get().unit_range if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") From 19476325f3d730a88bae3a47537afd63a3e76331 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 19 Mar 2024 16:52:46 +0000 Subject: [PATCH 21/23] fix doctest --- src/gt4py/next/common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index f5202a86dc..8a80758f84 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -477,10 +477,12 @@ def __and__(self, other: Domain) -> Domain: >>> I = Dimension("I") >>> J = Dimension("J") - >>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6))) + >>> Domain(NamedRange(I, UnitRange(-1, 3))) & Domain(NamedRange(I, UnitRange(1, 6))) Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) - >>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6))) + >>> Domain(NamedRange(I, UnitRange(-1, 3)), NamedRange(J, UnitRange(2, 4))) & Domain( + ... NamedRange(I, UnitRange(1, 6)) + ... ) Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) From d6103fe2c7d6dfd61ee52cf7ca291928c8868234 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 20 Mar 2024 19:19:40 +0100 Subject: [PATCH 22/23] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- src/gt4py/next/common.py | 10 ++++------ src/gt4py/next/embedded/common.py | 4 ++-- src/gt4py/next/iterator/embedded.py | 3 +-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 8a80758f84..8d3c2b59ab 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -251,7 +251,7 @@ def __str__(self) -> str: _Rng = TypeVar( "_Rng", - UnitRange[int, int], + FiniteUnitRange, UnitRange[Infinity, int], UnitRange[int, Infinity], UnitRange[Infinity, Infinity], @@ -341,16 +341,14 @@ def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[slice]: def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: return ( is_int_index(v) - or isinstance(v, NamedRange) - or isinstance(v, NamedIndex) - or isinstance(v, slice) + or isinstance(v, (NamedRange, NamedIndex, slice)) or v is Ellipsis ) def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: return isinstance(v, Sequence) and all( - isinstance(e, NamedRange) or isinstance(e, NamedIndex) for e in v + isinstance(e, (NamedRange, NamedIndex)) for e in v ) @@ -493,7 +491,7 @@ def __and__(self, other: Domain) -> Domain: _broadcast_ranges(broadcast_dims, other.dims, other.ranges), ) ) - return Domain(dims=broadcast_dims, ranges=intersected_ranges) # TODO + return Domain(dims=broadcast_dims, ranges=intersected_ranges) def __str__(self) -> str: return f"Domain({', '.join(f'{e}' for e in self)})" diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 8129c12e54..cdb0d3a5fd 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -152,8 +152,8 @@ def restrict_to_intersection( def iterate_domain( domain: common.Domain, ) -> Iterator[tuple[common.NamedIndex]]: - for i in itertools.product(*[list(r) for r in domain.ranges]): - yield tuple(common.NamedIndex(*e) for e in zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]` + for idx in itertools.product(*(list(r) for r in domain.ranges)): + yield tuple(common.NamedIndex(d, i) for d, i in zip(domain.dims, idx)) # type: ignore[misc] # trust me, `idx` is `tuple[int, ...]` def _expand_ellipsis( diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 516e52799f..0bc9698c3f 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1095,8 +1095,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - def restrict(self, item: common.AnyIndexSpec) -> Self: if ( - common.is_absolute_index_sequence(item) # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off - and all(isinstance(e, common.NamedIndex) for e in item) + isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item) ): assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below d, r = item[0] From 2f54512b4d30a4fa2b3c0625f37aae00f26e9bdc Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 20 Mar 2024 18:23:56 +0000 Subject: [PATCH 23/23] address review comments --- src/gt4py/next/common.py | 23 ++++++++--------------- src/gt4py/next/iterator/embedded.py | 4 +--- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 8d3c2b59ab..2936e4163a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -303,10 +303,10 @@ def __str__(self) -> str: class NamedIndex(NamedTuple): dim: Dimension - index: IntIndex # type: ignore[assignment] # overriding tuple.index + value: IntIndex def __str__(self) -> str: - return f"{self.dim}={self.index}" + return f"{self.dim}={self.value}" FiniteNamedRange: TypeAlias = NamedRange[FiniteUnitRange] @@ -339,17 +339,11 @@ def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[slice]: def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: - return ( - is_int_index(v) - or isinstance(v, (NamedRange, NamedIndex, slice)) - or v is Ellipsis - ) + return is_int_index(v) or isinstance(v, (NamedRange, NamedIndex, slice)) or v is Ellipsis def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: - return isinstance(v, Sequence) and all( - isinstance(e, (NamedRange, NamedIndex)) for e in v - ) + return isinstance(v, Sequence) and all(isinstance(e, (NamedRange, NamedIndex)) for e in v) def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: @@ -381,7 +375,7 @@ class Domain(Sequence[NamedRange[_Rng]], Generic[_Rng]): def __init__( self, - *args: NamedRange, + *args: NamedRange[_Rng], dims: Optional[Sequence[Dimension]] = None, ranges: Optional[Sequence[_Rng]] = None, ) -> None: @@ -414,10 +408,9 @@ def __init__( raise ValueError( f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." ) - dims_new = (arg.dim for arg in args) if args else () - ranges_new = (arg.unit_range for arg in args) if args else () - object.__setattr__(self, "dims", tuple(dims_new)) - object.__setattr__(self, "ranges", tuple(ranges_new)) + dims, ranges = zip(*args) if args else ((), ()) + object.__setattr__(self, "dims", tuple(dims)) + object.__setattr__(self, "ranges", tuple(ranges)) if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not '{self.dims}'.") diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 0bc9698c3f..c9552e7138 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1094,9 +1094,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - raise NotImplementedError() def restrict(self, item: common.AnyIndexSpec) -> Self: - if ( - isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item) - ): + if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item): assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below d, r = item[0] assert d == self._dimension