Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: new domain slice syntax #1453

Merged
merged 25 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
361bdcf
edits for indexing
nfarabullini Feb 14, 2024
075a4df
added property in dimension class
nfarabullini Feb 14, 2024
9b70c0c
edit to dim class
nfarabullini Feb 14, 2024
c028b65
edit to dim class
nfarabullini Feb 14, 2024
b4dc92c
edit to dim class
nfarabullini Feb 14, 2024
e1dacd7
added assertion and fixed typing
nfarabullini Feb 14, 2024
79148a0
edit to if condition
nfarabullini Feb 14, 2024
df5b22f
Update src/gt4py/next/common.py
nfarabullini Feb 15, 2024
204fcb5
Update src/gt4py/next/embedded/nd_array_field.py
nfarabullini Feb 15, 2024
4060fa5
Update src/gt4py/next/embedded/nd_array_field.py
nfarabullini Feb 15, 2024
39df807
Update tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
nfarabullini Feb 15, 2024
92f8a17
edits following review
nfarabullini Feb 15, 2024
a826218
edits following review
nfarabullini Feb 15, 2024
b429fdf
edits following discussion with Hannes
nfarabullini Feb 16, 2024
689fd47
edits following review comment
nfarabullini Feb 16, 2024
e21e582
edits following review comment
nfarabullini Feb 16, 2024
b61a62a
ran pre-commit
nfarabullini Feb 16, 2024
c02a525
Update src/gt4py/next/embedded/common.py
nfarabullini Feb 19, 2024
e0da304
edits following review
nfarabullini Feb 19, 2024
acfc675
removed one line
nfarabullini Feb 19, 2024
c25f955
Update src/gt4py/next/common.py
nfarabullini Feb 23, 2024
4116290
Update src/gt4py/next/embedded/common.py
nfarabullini Feb 23, 2024
e2f8761
Update src/gt4py/next/embedded/common.py
nfarabullini Feb 23, 2024
f51aca9
edits following review
nfarabullini Feb 23, 2024
5f529a4
fixes
nfarabullini Feb 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class Dimension:
def __str__(self):
return f"{self.value}[{self.kind}]"

def __call__(self, val: int) -> NamedIndex:
return self, val


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
Expand Down Expand Up @@ -272,7 +275,10 @@ def unit_range(r: RangeLike) -> UnitRange:
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange
NamedSlice: TypeAlias = (
slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/
)
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice
AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement
AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex]
RelativeIndexSequence: TypeAlias = tuple[
Expand Down Expand Up @@ -307,6 +313,10 @@ def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
)


def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]:
return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop))


def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]:
return (
is_int_index(v)
Expand Down
29 changes: 29 additions & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,32 @@ def _find_index_of_dim(
if dim == d:
return i
return None


def canonicalize_any_index_sequence(
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
index: common.AnyIndexSpec,
) -> common.AnyIndexSpec:
# TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice`
new_index: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(new_index, tuple) and all(isinstance(i, slice) for i in new_index):
new_index = tuple([_named_slice_to_named_range(i) for i in new_index]) # type: ignore[arg-type, assignment] # all i's are slices as per if statement
return new_index


def _named_slice_to_named_range(
idx: common.NamedSlice,
) -> common.NamedRange | common.NamedSlice:
assert hasattr(idx, "start") and hasattr(idx, "stop")
if common.is_named_slice(idx):
idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = idx.start[0], idx.start[1], idx.stop[0], idx.stop[1] # type: ignore[attr-defined]
if idx_start_0 != idx_stop_0:
raise IndexError(
f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'."
)
assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int)
return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1))
if common.is_named_index(idx.start) and idx.stop is None:
raise IndexError(f"Upper bound needs to be specified for {idx}.")
if common.is_named_index(idx.stop) and idx.start is None:
raise IndexError(f"Lower bound needs to be specified for {idx}.")
return idx
1 change: 1 addition & 0 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def __invert__(self) -> NdArrayField:
def _slice(
self, index: common.AnyIndexSpec
) -> tuple[common.Domain, common.RelativeIndexSequence]:
index = embedded_common.canonicalize_any_index_sequence(index)
new_domain = embedded_common.sub_domain(self.domain, index)

index_sequence = common.as_any_index_sequence(index)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __call__(self, *args):
def make_node(o):
if isinstance(o, Node):
return o
if isinstance(o, common.Dimension):
return AxisLiteral(value=o.value)
if callable(o):
if o.__name__ == "<lambda>":
return lambdadef(o)
Expand All @@ -156,8 +158,6 @@ def make_node(o):
return OffsetLiteral(value=o.value)
if isinstance(o, core_defs.Scalar):
return im.literal_from_value(o)
if isinstance(o, common.Dimension):
return AxisLiteral(value=o.value)
if isinstance(o, tuple):
return _f("make_tuple", *(make_node(arg) for arg in o))
if o is None:
Expand Down
35 changes: 34 additions & 1 deletion tests/next_tests/unit_tests/embedded_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from gt4py.next import common
from gt4py.next.common import UnitRange
from gt4py.next.embedded import exceptions as embedded_exceptions
from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain
from gt4py.next.embedded.common import (
_slice_range,
canonicalize_any_index_sequence,
iterate_domain,
sub_domain,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -147,3 +152,31 @@ def test_iterate_domain():
testee = list(iterate_domain(domain))

assert testee == ref


@pytest.mark.parametrize(
"slices, expected",
[
[slice(I(3), I(4)), ((I, common.UnitRange(3, 4)),)],
[
(slice(J(3), J(6)), slice(I(3), I(5))),
((J, common.UnitRange(3, 6)), (I, common.UnitRange(3, 5))),
],
[slice(I(1), J(7)), IndexError],
[
slice(I(1), None),
IndexError,
],
[
slice(None, K(8)),
IndexError,
],
],
)
def test_slicing(slices, expected):
if expected is IndexError:
with pytest.raises(IndexError):
canonicalize_any_index_sequence(slices)
else:
testee = canonicalize_any_index_sequence(slices)
assert testee == expected
43 changes: 43 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,49 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape):
assert indexed_field.domain.dims == expected_dimensions


def test_absolute_indexing_dim_sliced():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)
indexed_field_1 = field[JDim(8) : JDim(10), IDim(5) : IDim(9)]
expected = field[(IDim, UnitRange(5, 9)), (JDim, UnitRange(8, 10))]

assert common.is_field(indexed_field_1)
assert indexed_field_1 == expected


def test_absolute_indexing_dim_sliced_single_slice():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)
indexed_field_1 = field[KDim(11)]
indexed_field_2 = field[(KDim, 11)]

assert common.is_field(indexed_field_1)
assert indexed_field_1 == indexed_field_2


def test_absolute_indexing_wrong_dim_sliced():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)

with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'JDim' and 'IDim'."):
field[JDim(8) : IDim(10)]


def test_absolute_indexing_empty_dim_sliced():
domain = common.Domain(
dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25))
)
field = common._field(np.ones((5, 10, 15)), domain=domain)
with pytest.raises(IndexError, match="Lower bound needs to be specified"):
field[: IDim(10)]


def test_absolute_indexing_value_return():
domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15)))
field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain)
Expand Down
Loading