From 1d99f906adde4fe96d49338c6dbd713aa526e0b8 Mon Sep 17 00:00:00 2001 From: Stephannie Jimenez Date: Wed, 26 Jul 2023 17:19:00 -0500 Subject: [PATCH] clarify behavior when dtype=None in sum, prod and trace --- src/array_api_stubs/_draft/linalg.py | 9 +++++---- .../_draft/statistical_functions.py | 18 ++++++++++-------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/array_api_stubs/_draft/linalg.py b/src/array_api_stubs/_draft/linalg.py index 3cec1770e..93711b13a 100644 --- a/src/array_api_stubs/_draft/linalg.py +++ b/src/array_api_stubs/_draft/linalg.py @@ -718,10 +718,11 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr data type of the returned array. If ``None``, - if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``. - - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. - - if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type. - - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. - - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). + - if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``, + - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. + - if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type. + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``. diff --git a/src/array_api_stubs/_draft/statistical_functions.py b/src/array_api_stubs/_draft/statistical_functions.py index 93faaf31e..066da0da5 100644 --- a/src/array_api_stubs/_draft/statistical_functions.py +++ b/src/array_api_stubs/_draft/statistical_functions.py @@ -143,10 +143,11 @@ def prod( data type of the returned array. If ``None``, - if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``. - - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. - - if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type. - - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. - - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). + - if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``, + - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. + - if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type. + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the product. Default: ``None``. @@ -240,10 +241,11 @@ def sum( data type of the returned array. If ``None``, - if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``. - - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. - - if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type. - - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. - - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). + - if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``, + - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. + - if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type. + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.