Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement flip #528

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ jobs:
# not implemented
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
array_api_tests/test_manipulation_functions.py::test_flip
array_api_tests/test_sorting_functions.py
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var
Expand Down
2 changes: 1 addition & 1 deletion api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `broadcast_to` | :white_check_mark: | | |
| | `concat` | :white_check_mark: | | |
| | `expand_dims` | :white_check_mark: | | |
| | `flip` | :x: | | Needs indexing with step=-1, [#114](https://github.com/cubed-dev/cubed/issues/114) |
| | `flip` | :white_check_mark: | | |
| | `permute_dims` | :white_check_mark: | | |
| | `repeat` | :x: | 2023.12 | |
| | `reshape` | :white_check_mark: | | Partial implementation |
Expand Down
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
broadcast_to,
concat,
expand_dims,
flip,
moveaxis,
permute_dims,
reshape,
Expand All @@ -292,6 +293,7 @@
"broadcast_to",
"concat",
"expand_dims",
"flip",
"moveaxis",
"permute_dims",
"reshape",
Expand Down
2 changes: 2 additions & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@
broadcast_to,
concat,
expand_dims,
flip,
moveaxis,
permute_dims,
reshape,
Expand All @@ -234,6 +235,7 @@
"broadcast_to",
"concat",
"expand_dims",
"flip",
"moveaxis",
"permute_dims",
"reshape",
Expand Down
44 changes: 44 additions & 0 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,50 @@ def flatten(x):
return reshape(x, (-1,))


def flip(x, /, *, axis=None):
if axis is None:
axis = tuple(range(x.ndim)) # all axes
if not isinstance(axis, tuple):
axis = (axis,)
axis = validate_axis(axis, x.ndim)
return map_direct(
_flip,
x,
shape=x.shape,
dtype=x.dtype,
chunks=x.chunks,
extra_projected_mem=x.chunkmem,
target_chunks=x.chunks,
axis=axis,
)


def _flip(x, *arrays, target_chunks=None, axis=None, block_id=None):
array = arrays[0].zarray # underlying Zarr array (or virtual array)
chunks = target_chunks

# produce a key that has slices (except for axis dimensions, which are replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = list(get_item(chunks, idx))

for ax in axis:
# determine the start and stop indexes for this block along the axis dimension
chunksize = to_chunksize(chunks)
start = block_id[ax] * chunksize[ax]
stop = start + x.shape[ax]

# flip start and stop
axis_len = array.shape[ax]
start, stop = axis_len - stop, axis_len - start

# replace with slice
key[ax] = slice(start, stop)

key = tuple(key)

return nxp.flip(array[key], axis=axis)


def moveaxis(
x,
source,
Expand Down
25 changes: 25 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,31 @@ def test_expand_dims(spec, executor):
assert_array_equal(b.compute(executor=executor), np.expand_dims([1, 2, 3], 0))


@pytest.mark.parametrize(
"shape, chunks, axis",
[
((10,), (4,), None),
((10,), (4,), 0),
((10, 7), (4, 3), None),
((10, 7), (4, 3), 0),
((10, 7), (4, 3), 1),
((10, 7), (4, 3), (0, 1)),
((10, 7), (4, 3), -1),
],
)
def test_flip(executor, shape, chunks, axis):
x = np.random.randint(10, size=shape)
a = xp.asarray(x, chunks=chunks)
b = xp.flip(a, axis=axis)

assert b.chunks == a.chunks

assert_array_equal(
b.compute(executor=executor),
np.flip(x, axis=axis),
)


def test_moveaxis(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.moveaxis(a, [0, -1], [-1, 0])
Expand Down
Loading