From bd1e695db7cbfc11b1d31bc4dd33171d7a83f035 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 13 Jul 2023 09:48:24 -0700 Subject: [PATCH 1/4] Add specification for moving array axes to new positions --- .../manipulation_functions.rst | 1 + .../_draft/manipulation_functions.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/spec/draft/API_specification/manipulation_functions.rst b/spec/draft/API_specification/manipulation_functions.rst index fc0e752b9..7eb7fa8b0 100644 --- a/spec/draft/API_specification/manipulation_functions.rst +++ b/spec/draft/API_specification/manipulation_functions.rst @@ -23,6 +23,7 @@ Objects in API concat expand_dims flip + moveaxis permute_dims reshape roll diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index d62cd4f07..f71a75517 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -99,6 +99,26 @@ def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ +def moveaxis(x: array, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]], /) -> array: + """ + Moves array axes (dimensions) to new positions, while leaving other axes in their original positions. + + Parameters + ---------- + x: array + input array. + source: Union[int, Tuple[int, ...]] + Axes to move. Provided axes must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the open-interval ``(-N, N)``. + destination: Union[int, Tuple[int, ...]] + indices defining the desired positions for each respective ``source`` axis index. Provided indices must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the open-interval ``(-N, N)``. + + Returns + ------- + out: array + an array containing reordered axes. The returned array must have the same data type as ``x``. + """ + + def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: """ Permutes the axes (dimensions) of an array ``x``. @@ -240,6 +260,7 @@ def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]: "concat", "expand_dims", "flip", + "moveaxis", "permute_dims", "reshape", "roll", From 149d340c77e3166a9e60a7ff75621769818780ca Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 19 Sep 2023 14:38:12 -0700 Subject: [PATCH 2/4] Fix lint error --- src/array_api_stubs/_draft/manipulation_functions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 2b6689647..4c2a4ab4a 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -115,7 +115,12 @@ def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ -def moveaxis(x: array, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]], /) -> array: +def moveaxis( + x: array, + source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]], + / +) -> array: """ Moves array axes (dimensions) to new positions, while leaving other axes in their original positions. From f05c22fb42bd3c69b1589ef9c7162ebf33a12d09 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 19 Sep 2023 14:44:54 -0700 Subject: [PATCH 3/4] Fix index range --- src/array_api_stubs/_draft/manipulation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 4c2a4ab4a..4bfeb1caa 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -129,9 +129,9 @@ def moveaxis( x: array input array. source: Union[int, Tuple[int, ...]] - Axes to move. Provided axes must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the open-interval ``(-N, N)``. + Axes to move. Provided axes must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the half-open interval ``[-N, N)``. destination: Union[int, Tuple[int, ...]] - indices defining the desired positions for each respective ``source`` axis index. Provided indices must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the open-interval ``(-N, N)``. + indices defining the desired positions for each respective ``source`` axis index. Provided indices must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the half-open interval ``[-N, N)``. Returns ------- From 8149b33b51bf2d7064bd72731d99ccc31e9cc265 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 19 Sep 2023 14:45:49 -0700 Subject: [PATCH 4/4] Fix lint error --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 4bfeb1caa..2bc929134 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -119,7 +119,7 @@ def moveaxis( x: array, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]], - / + /, ) -> array: """ Moves array axes (dimensions) to new positions, while leaving other axes in their original positions.