Skip to content

Commit

Permalink
Add support for broadcasting to linalg.cross (#417)
Browse files Browse the repository at this point in the history
* Make explicit that broadcasting only applies to non-compute dimensions in vecdot

* Add support for broadcasting to `linalg.cross`

* Update copy

* Update copy

* Fix spacing
  • Loading branch information
kgryte authored Sep 19, 2022
1 parent 55b8fb0 commit eba54b3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
19 changes: 16 additions & 3 deletions spec/API_specification/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,34 @@ def cholesky(x: array, /, *, upper: bool = False) -> array:

def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
"""
Returns the cross product of 3-element vectors. If ``x1`` and ``x2`` are multi-dimensional arrays (i.e., both have a rank greater than ``1``), then the cross-product of each pair of corresponding 3-element vectors is independently computed.
Returns the cross product of 3-element vectors.
If ``x1`` and/or ``x2`` are multi-dimensional arrays (i.e., the broadcasted result has a rank greater than ``1``), then the cross-product of each pair of corresponding 3-element vectors is independently computed.
Parameters
----------
x1: array
first input array. Should have a real-valued data type.
x2: array
second input array. Must have the same shape as ``x1``. Should have a real-valued data type.
second input array. Must be compatible with ``x1`` for all non-compute axes (see :ref:`broadcasting`). The size of the axis over which to compute the cross product must be the same size as the respective axis in ``x1``. Should have a real-valued data type.
.. note::
The compute axis (dimension) must not be broadcasted.
axis: int
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. If set to ``-1``, the function computes the cross product for vectors defined by the last axis (dimension). Default: ``-1``.
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``.
Returns
-------
out: array
an array containing the cross products. The returned array must have a data type determined by :ref:`type-promotion`.
**Raises**
- if provided an invalid ``axis``.
- if the size of the axis over which to compute the cross product is not equal to ``3``.
- if the size of the axis over which to compute the cross product is not the same (before broadcasting) for both ``x1`` and ``x2``.
"""

def det(x: array, /) -> array:
Expand Down
8 changes: 4 additions & 4 deletions spec/API_specification/array_api/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1: array
first input array. Should have a real-valued data type.
x2: array
second input array. Should have a real-valued data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal.
second input array. Must be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product must be the same size as the respective axis in ``x1``. Should have a real-valued data type.
.. note::
Contracted axes (dimensions) must not be broadcasted.
The contracted axis (dimension) must not be broadcasted.
axis:int
axis: int
axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
Returns
Expand All @@ -109,7 +109,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
**Raises**
- if provided an invalid ``axis``.
- if the size of the axis over which to compute the dot product is not the same for both ``x1`` and ``x2``.
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
"""

__all__ = ['matmul', 'matrix_transpose', 'tensordot', 'vecdot']

0 comments on commit eba54b3

Please sign in to comment.