From eba54b35b040bb08359fdca7aaeb0340ecb5c87c Mon Sep 17 00:00:00 2001 From: Athan Date: Mon, 19 Sep 2022 00:43:34 -0700 Subject: [PATCH] Add support for broadcasting to `linalg.cross` (#417) * Make explicit that broadcasting only applies to non-compute dimensions in vecdot * Add support for broadcasting to `linalg.cross` * Update copy * Update copy * Fix spacing --- spec/API_specification/array_api/linalg.py | 19 ++++++++++++++++--- .../array_api/linear_algebra_functions.py | 8 ++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index abce23437..5336d93c6 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -41,21 +41,34 @@ def cholesky(x: array, /, *, upper: bool = False) -> array: def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: """ - Returns the cross product of 3-element vectors. If ``x1`` and ``x2`` are multi-dimensional arrays (i.e., both have a rank greater than ``1``), then the cross-product of each pair of corresponding 3-element vectors is independently computed. + Returns the cross product of 3-element vectors. + + If ``x1`` and/or ``x2`` are multi-dimensional arrays (i.e., the broadcasted result has a rank greater than ``1``), then the cross-product of each pair of corresponding 3-element vectors is independently computed. Parameters ---------- x1: array first input array. Should have a real-valued data type. x2: array - second input array. Must have the same shape as ``x1``. Should have a real-valued data type. + second input array. Must be compatible with ``x1`` for all non-compute axes (see :ref:`broadcasting`). The size of the axis over which to compute the cross product must be the same size as the respective axis in ``x1``. Should have a real-valued data type. + + .. note:: + The compute axis (dimension) must not be broadcasted. + axis: int - the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. If set to ``-1``, the function computes the cross product for vectors defined by the last axis (dimension). Default: ``-1``. + the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``. Returns ------- out: array an array containing the cross products. The returned array must have a data type determined by :ref:`type-promotion`. + + + **Raises** + + - if provided an invalid ``axis``. + - if the size of the axis over which to compute the cross product is not equal to ``3``. + - if the size of the axis over which to compute the cross product is not the same (before broadcasting) for both ``x1`` and ``x2``. """ def det(x: array, /) -> array: diff --git a/spec/API_specification/array_api/linear_algebra_functions.py b/spec/API_specification/array_api/linear_algebra_functions.py index 1284c9c4e..f8f15e5b0 100644 --- a/spec/API_specification/array_api/linear_algebra_functions.py +++ b/spec/API_specification/array_api/linear_algebra_functions.py @@ -92,12 +92,12 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: x1: array first input array. Should have a real-valued data type. x2: array - second input array. Should have a real-valued data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal. + second input array. Must be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product must be the same size as the respective axis in ``x1``. Should have a real-valued data type. .. note:: - Contracted axes (dimensions) must not be broadcasted. + The contracted axis (dimension) must not be broadcasted. - axis:int + axis: int axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``. Returns @@ -109,7 +109,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: **Raises** - if provided an invalid ``axis``. - - if the size of the axis over which to compute the dot product is not the same for both ``x1`` and ``x2``. + - if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``. """ __all__ = ['matmul', 'matrix_transpose', 'tensordot', 'vecdot']