Skip to content

Commit

Permalink
feat[next]: new domain slice syntax (#1453)
Browse files Browse the repository at this point in the history
New domain slice syntax, e.g. f[I(-1):I(5)]
  • Loading branch information
nfarabullini authored Feb 26, 2024
1 parent 117de0a commit 7dea36a
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 4 deletions.
12 changes: 11 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class Dimension:
def __str__(self):
return f"{self.value}[{self.kind}]"

def __call__(self, val: int) -> NamedIndex:
return self, val


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
Expand Down Expand Up @@ -272,7 +275,10 @@ def unit_range(r: RangeLike) -> UnitRange:
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange
NamedSlice: TypeAlias = (
slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/
)
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice
AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement
AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex]
RelativeIndexSequence: TypeAlias = tuple[
Expand Down Expand Up @@ -307,6 +313,10 @@ def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
)


def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]:
return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop))


def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]:
return (
is_int_index(v)
Expand Down
29 changes: 29 additions & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,32 @@ def _find_index_of_dim(
if dim == d:
return i
return None


def canonicalize_any_index_sequence(
index: common.AnyIndexSpec,
) -> common.AnyIndexSpec:
# TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice`
new_index: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index
if isinstance(new_index, tuple) and all(isinstance(i, slice) for i in new_index):
new_index = tuple([_named_slice_to_named_range(i) for i in new_index]) # type: ignore[arg-type, assignment] # all i's are slices as per if statement
return new_index


def _named_slice_to_named_range(
idx: common.NamedSlice,
) -> common.NamedRange | common.NamedSlice:
assert hasattr(idx, "start") and hasattr(idx, "stop")
if common.is_named_slice(idx):
idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = idx.start[0], idx.start[1], idx.stop[0], idx.stop[1] # type: ignore[attr-defined]
if idx_start_0 != idx_stop_0:
raise IndexError(
f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'."
)
assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int)
return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1))
if common.is_named_index(idx.start) and idx.stop is None:
raise IndexError(f"Upper bound needs to be specified for {idx}.")
if common.is_named_index(idx.stop) and idx.start is None:
raise IndexError(f"Lower bound needs to be specified for {idx}.")
return idx
1 change: 1 addition & 0 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def __invert__(self) -> NdArrayField:
def _slice(
self, index: common.AnyIndexSpec
) -> tuple[common.Domain, common.RelativeIndexSequence]:
index = embedded_common.canonicalize_any_index_sequence(index)
new_domain = embedded_common.sub_domain(self.domain, index)

index_sequence = common.as_any_index_sequence(index)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __call__(self, *args):
def make_node(o):
if isinstance(o, Node):
return o
if isinstance(o, common.Dimension):
return AxisLiteral(value=o.value)
if callable(o):
if o.__name__ == "<lambda>":
return lambdadef(o)
Expand All @@ -156,8 +158,6 @@ def make_node(o):
return OffsetLiteral(value=o.value)
if isinstance(o, core_defs.Scalar):
return im.literal_from_value(o)
if isinstance(o, common.Dimension):
return AxisLiteral(value=o.value)
if isinstance(o, tuple):
return _f("make_tuple", *(make_node(arg) for arg in o))
if o is None:
Expand Down
35 changes: 34 additions & 1 deletion tests/next_tests/unit_tests/embedded_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from gt4py.next import common
from gt4py.next.common import UnitRange
from gt4py.next.embedded import exceptions as embedded_exceptions
from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain
from gt4py.next.embedded.common import (
_slice_range,
canonicalize_any_index_sequence,
iterate_domain,
sub_domain,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -147,3 +152,31 @@ def test_iterate_domain():
testee = list(iterate_domain(domain))

assert testee == ref


@pytest.mark.parametrize(
"slices, expected",
[
[slice(I(3), I(4)), ((I, common.UnitRange(3, 4)),)],
[
(slice(J(3), J(6)), slice(I(3), I(5))),
((J, common.UnitRange(3, 6)), (I, common.UnitRange(3, 5))),
],
[slice(I(1), J(7)), IndexError],
[
slice(I(1), None),
IndexError,
],
[
slice(None, K(8)),
IndexError,
],
],
)
def test_slicing(slices, expected):
if expected is IndexError:
with pytest.raises(IndexError):
canonicalize_any_index_sequence(slices)
else:
testee = canonicalize_any_index_sequence(slices)
assert testee == expected
43 changes: 43 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,49 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape):
assert indexed_field.domain.dims == expected_dimensions


def test_absolute_indexing_dim_sliced():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)
indexed_field_1 = field[JDim(8) : JDim(10), IDim(5) : IDim(9)]
expected = field[(IDim, UnitRange(5, 9)), (JDim, UnitRange(8, 10))]

assert common.is_field(indexed_field_1)
assert indexed_field_1 == expected


def test_absolute_indexing_dim_sliced_single_slice():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)
indexed_field_1 = field[KDim(11)]
indexed_field_2 = field[(KDim, 11)]

assert common.is_field(indexed_field_1)
assert indexed_field_1 == indexed_field_2


def test_absolute_indexing_wrong_dim_sliced():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)

with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'JDim' and 'IDim'."):
field[JDim(8) : IDim(10)]


def test_absolute_indexing_empty_dim_sliced():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)
with pytest.raises(IndexError, match="Lower bound needs to be specified"):
field[: IDim(10)]


def test_absolute_indexing_value_return():
domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15)))
field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain)
Expand Down

0 comments on commit 7dea36a

Please sign in to comment.