diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 2fdc8e7fd5..d40a5b9d43 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -30,7 +30,11 @@
     float_dtypes,
     lvector,
 )
-from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
+from pytensor.tensor.utils import (
+    broadcast_static_dim_lengths,
+    import_func_from_string,
+    normalize_reduce_axis,
+)
 from pytensor.tensor.variable import TensorVariable
 from pytensor.utils import uniq
 
@@ -1371,7 +1375,6 @@ def _acc_dtype(self, idtype):
 
     def make_node(self, input):
         input = as_tensor_variable(input)
-        inp_dims = input.type.ndim
         inp_dtype = input.type.dtype
 
         # We need to redefine make_node so that, if self.dtype is None,
@@ -1383,29 +1386,19 @@ def make_node(self, input):
         assert dtype is not None
         assert acc_dtype is not None
 
-        axis = self.axis
+        axis = normalize_reduce_axis(self.axis, ndim=input.type.ndim)
 
-        # scalar inputs are treated as 1D regarding axis in this `Op`
-        if axis is not None:
-            try:
-                axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims))
-            except np.AxisError:
-                raise np.AxisError(axis, ndim=inp_dims)
+        if axis != self.axis or dtype != self.dtype or acc_dtype != self.acc_dtype:
+            op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
+        else:
+            op = self
 
+        if axis is None:
+            out_shape = ()
+        else:
             out_shape = tuple(
                 s for i, s in enumerate(input.type.shape) if i not in axis
             )
-        else:
-            out_shape = ()
-
-        if (
-            (axis is not None and any(a < 0 for a in axis))
-            or dtype != self.dtype
-            or acc_dtype != self.acc_dtype
-        ):
-            op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
-        else:
-            op = self
 
         output = TensorType(dtype=dtype, shape=out_shape)()
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index d7c69135ae..b55adb0312 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -8,7 +8,6 @@
 
 from pytensor import config, printing
 from pytensor import scalar as ps
-from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -26,9 +25,9 @@
     cast,
     concatenate,
     constant,
+    expand_dims,
     stack,
     switch,
-    zeros_like,
 )
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
@@ -45,14 +44,11 @@
     continuous_dtypes,
     discrete_dtypes,
     int_dtypes,
-    integer_dtypes,
     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.type_other import NoneConst
-from pytensor.tensor.utils import as_list
+from pytensor.tensor.utils import as_list, normalize_reduce_axis
 from pytensor.tensor.variable import (
-    TensorConstant,
     TensorVariable,
     _tensor_py_operators,
 )
@@ -157,7 +153,7 @@ class Argmax(COp):
 
     def __init__(self, axis):
         if axis is not None:
-            axis = tuple(axis)
+            axis = tuple(sorted(axis))
         self.axis = axis
 
     def get_params(self, node):
@@ -168,7 +164,7 @@ def get_params(self, node):
             c_axis = np.int64(-1)
         return self.params_type.get_params(c_axis=c_axis)
 
-    def make_node(self, x, axis=None):
+    def make_node(self, x):
         x = as_tensor_variable(x)
         if self.axis is None:
             all_axes = list(range(x.ndim))
@@ -198,7 +194,9 @@ def perform(self, node, inp, outs):
         # Work around
         keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
         # Not-reduced axes in front
-        transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
+        transposed_x = np.transpose(
+            x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64")))
+        )
         kept_shape = transposed_x.shape[: len(keep_axes)]
         reduced_shape = transposed_x.shape[len(keep_axes) :]
         new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
@@ -214,7 +212,7 @@ def c_code(self, node, name, inp, out, sub):
         if self.axis is None:
             axis_code = "axis = NPY_MAXDIMS;"
         else:
-            if len(self.axis) > 1:
+            if len(self.axis) != 1:
                 raise NotImplementedError()
             # params is only used here for now
             axis_code = """
@@ -253,7 +251,7 @@ def c_code(self, node, name, inp, out, sub):
         return ret % locals()
 
     def c_code_cache_version(self):
-        return (1,)
+        return (2,)
 
     def infer_shape(self, fgraph, node, shapes):
         (ishape,) = shapes
@@ -277,6 +275,40 @@ def grad(self, inp, grads):
         return [x.zeros_like()]
 
 
+def argmax(x: TensorLike, axis=None, keepdims: bool = False):
+    """
+    Returns indices of maximum elements obtained by iterating over given axis.
+
+    When axis is None (the default value), the argmax is performed
+    over the flattened tensor.
+
+    Parameters
+    ----------
+    x: TensorLike
+        Array on which to compute argmax
+    axis:
+        Axis along which to compute argmax. Unlike numpy, multiple partial axes are supported.
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in
+        the result as dimensions with size one. With this option, the result
+        will broadcast correctly against the original tensor.
+
+    Returns
+    -------
+    TensorVariable
+        TensorVariable representing the argmax operation
+
+    """
+    x = as_tensor_variable(x)
+    axis = normalize_reduce_axis(axis, ndim=x.type.ndim)
+    out = Argmax(axis)(x)
+
+    if keepdims:
+        out = makeKeepDims(x, out, axis)
+
+    return out
+
+
 @_vectorize_node.register(Argmax)
 def vectorize_argmax_node(op, node, batch_x):
     core_ndim = node.inputs[0].type.ndim
@@ -297,85 +329,9 @@ def makeKeepDims(x, y, axis):
 
     """
     x = as_tensor_variable(x)
-    y = as_tensor_variable(y)
-
     if axis is None:
         axis = list(range(x.type.ndim))
-    elif isinstance(axis, int | np.integer):
-        axis = [axis]
-    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-        axis = [int(axis)]
-    else:
-        axis = [int(a) for a in axis]
-    newaxis = []
-    for a in axis:
-        if not isinstance(a, int):
-            raise ValueError("keepdims option can be used only with constant axis")
-        if a < 0:
-            a += x.type.ndim
-        newaxis.append(a)
-    i = 0
-    new_dims = []
-    for j, _ in enumerate(x.type.broadcastable):
-        if j in newaxis:
-            new_dims.append("x")
-        else:
-            new_dims.append(i)
-            i += 1
-    return DimShuffle(y.type.broadcastable, new_dims)(y)
-
-
-def check_and_normalize_axes(x, axis):
-    """Check axes, normalize and convert them to a Python list of integers.
-
-    Parameters
-    ----------
-    x: TensorVariable
-    axis: int, tuple or list of integers
-
-    Returns
-    -------
-    axis: list of integers
-        Return an empty list if argument is None.
-
-    """
-    x = as_tensor_variable(x)
-    if axis is None:
-        axis = []
-    elif isinstance(axis, int | np.integer) or (
-        isinstance(axis, np.ndarray) and axis.ndim == 0
-    ):
-        axis = [int(axis)]
-    elif isinstance(axis, tuple | list | np.ndarray):
-        axis = [int(i) for i in axis]
-    elif isinstance(axis, Variable):
-        if NoneConst.equals(axis):
-            axis = []
-        elif not isinstance(axis, TensorConstant):
-            raise TypeError(f"Computation needs a constant axis. Got {axis}")
-        else:
-            assert axis.dtype in integer_dtypes
-            if isinstance(axis.data, int | np.integer) or (
-                isinstance(axis.data, np.ndarray) and axis.data.ndim == 0
-            ):
-                axis = [int(axis.data)]
-            elif isinstance(axis.data, list | np.ndarray):
-                axis = [int(i) for i in axis.data]
-    else:
-        raise TypeError(
-            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
-        )
-    if len(axis) > 0:
-        for i in range(len(axis)):
-            if axis[i] < 0:
-                axis[i] += x.type.ndim
-            if axis[i] < 0 or axis[i] >= x.type.ndim:
-                raise ValueError(
-                    f"Computation needs a valid axis number for {int(x.type.ndim)}-D tensor. Got {int(axis[i])}"
-                )
-        axis = list(set(axis))
-        axis.sort()
-    return axis
+    return expand_dims(y, axis)
 
 
 def max_and_argmax(a, axis=None, keepdims=False):
@@ -396,28 +352,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
 
     """
     # Check axis and convert it to a Python list of integers.
    # Axis will be used as an op param of Max and Argmax.
-    a = as_tensor_variable(a)
-
-    is_axis_empty = False
-    if axis == ():
-        is_axis_empty = True
-
-    axis = check_and_normalize_axes(a, axis)
-
-    if len(axis) == 0 and not is_axis_empty:
-        axis = None
-
-    out = Max(axis)(a)
-
-    if not is_axis_empty:
-        argout = Argmax(axis)(a)
-    else:
-        argout = zeros_like(a, dtype="int64")
-
-    if keepdims:
-        out = makeKeepDims(a, out, axis)
-        argout = makeKeepDims(a, argout, axis)
-    return [out, argout]
+    return [
+        max(a, axis=axis, keepdims=keepdims),
+        argmax(a, axis=axis, keepdims=keepdims),
+    ]
 
 
 class FixedOpCAReduce(CAReduce):
@@ -466,7 +404,7 @@ def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
         return type(self)(axis=axis)
 
-    def grad(self, inp, grads):
+    def L_op(self, inputs, outputs, grads):
         # The strict sense mathematical gradient of the maximum function is
         # not calculated here for it is not defined at every point where some
         # coordinates are identical. However, since the latter set has null
@@ -480,53 +418,27 @@ def grad(self, inp, grads):
         # g_max has one less dimension than x, so you need to complete
         # g_max to x's shape when axis=0 the broadcasting mechanism
         # does it automatically
-        x = inp[0]
-        if self.axis is None:
-            self.axis = tuple(range(x.ndim))
-        axis = as_tensor_variable(self.axis)
-        (g_max,) = grads
-
-        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
-
-        # if the op is totally disconnected, so are its inputs
-        if g_max_disconnected:
-            return [DisconnectedType()()]
+        [x] = inputs
+        [out] = outputs
+        [g_out] = grads
 
-        # if NoneConst.equals(axis):
-        if axis is None:
-            axis_ = list(range(x.ndim))
-        else:
-            axis_ = axis
-        xmax = max(x, axis_)
-
-        # Raise the g_max and xmax to the same number of dim as the input.
-        pattern = []
-        out_dim = 0
-        if NoneConst.equals(axis):
-            # We are taking the max/argmax over all dimensions.
-            axis = None
-        for i in range(x.ndim):
-            if axis is None or i in axis.data:
-                pattern.append("x")
-            else:
-                pattern.append(out_dim)
-                out_dim += 1
-        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
-        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+        axis = tuple(range(x.ndim)) if self.axis is None else self.axis
+        out_pad = expand_dims(out, axis)
+        g_out_pad = expand_dims(g_out, axis)
 
         # Set the grad to the correct position.
-        g_x = eq(xmax_pad, x) * g_max_pad
+        g_x = eq(out_pad, x) * g_out_pad
         return (g_x,)
 
     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
             return [None, None]
         if len(self.axis) != 1:
-            raise ValueError("R_op supported for arg_max only for one axis!")
+            raise ValueError("R_op supported for max only for one axis!")
         if self.axis[0] > 1:
-            raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
+            raise ValueError("R_op supported for max only when axis is 0 or 1")
         if inputs[0].ndim != 2:
-            raise ValueError("R_op supported for arg_max only when input is a matrix")
+            raise ValueError("R_op supported for max only when input is a matrix")
         max_pos = Argmax(self.axis).make_node(*inputs).outputs
         # print(eval_points[0].eval())
         if self.axis[0] == 0:
@@ -565,35 +477,13 @@ def max(x, axis=None, keepdims=False):
     We return an error as numpy when we reduce a dim with a shape of 0.
 
     """
-    out = max_and_argmax(x, axis)[0]
+    out = Max(axis=axis)(x)
 
     if keepdims:
         out = makeKeepDims(x, out, axis)
     return out
 
 
-def argmax(x, axis=None, keepdims=False):
-    """
-    Returns indices of maximum elements obtained by iterating over given axis.
-
-    When axis is None (the default value), the argmax is performed
-    over the flattened tensor.
-
-    Parameters
-    ----------
-    keepdims : bool
-        If this is set to True, the axes which are reduced are left in
-        the result as dimensions with size one. With this option, the result
-        will broadcast correctly against the original tensor.
-
-    """
-    argout = max_and_argmax(x, axis)[1]
-
-    if keepdims:
-        argout = makeKeepDims(x, argout, axis)
-    return argout
-
-
 def min(x, axis=None, keepdims=False):
     """
     Returns minimum elements obtained by iterating over given axis.
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 6e34c27d43..1e7d16a612 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -509,8 +509,6 @@ def svd_uv_merge(fgraph, node):
         # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False.
         # First, iterate to see if there is an SVD Op that can be reused.
         for cl, _ in fgraph.clients[x]:
-            if cl == "output":
-                continue
             if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                 if not cl.op.core_op.compute_uv:
                     return {
@@ -529,8 +527,6 @@ def svd_uv_merge(fgraph, node):
         # We want rewrite if there is another one with compute_uv=True.
         # For this case, just reuse the `s` from the one with compute_uv=True.
         for cl, _ in fgraph.clients[x]:
-            if cl == "output":
-                continue
             if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                 if cl.op.core_op.compute_uv and (
                     len(fgraph.clients[cl.outputs[0]]) > 0
diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py
index b8ae1e780b..60ae8ebed8 100644
--- a/pytensor/tensor/utils.py
+++ b/pytensor/tensor/utils.py
@@ -1,7 +1,9 @@
 import re
 from collections.abc import Sequence
+from typing import cast
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple  # type: ignore
 
 import pytensor
 from pytensor.utils import hash_from_code
@@ -223,3 +225,19 @@ def operand_sig(operand_ndim: int, prefix: str) -> str:
         operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim)
     )
     return f"{inputs_sig}->{outputs_sig}"
+
+
+def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None:
+    """Normalize the axis parameter for reduce operations."""
+    if axis is None:
+        return None
+
+    # scalar inputs are treated as 1D regarding axis in reduce operations
+    if axis is not None:
+        try:
+            axis = normalize_axis_tuple(axis, ndim=max(1, ndim))
+        except np.AxisError:
+            raise np.AxisError(axis, ndim=ndim)
+
+    # TODO: If axis tuple is equivalent to None, return None for more canonicalization?
+    return cast(tuple, axis)
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index b66599e3ca..e86bd4ec17 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -154,7 +154,6 @@
     vectors,
     zvector,
 )
-from pytensor.tensor.type_other import NoneConst
 from tests import unittest_tools as utt
 from tests.link.test_link import make_function
 from tests.tensor.utils import (
@@ -767,9 +766,10 @@ def setup_method(self):
         Max.debug = 0
         Argmax.debug = 0
 
-    def test_basic(self):
+    @pytest.mark.parametrize("empty_axis", [(), None])
+    def test_empty_axis_scalar(self, empty_axis):
         n = as_tensor_variable(5)
-        v, i = eval_outputs(max_and_argmax(n, axis=()))
+        v, i = eval_outputs(max_and_argmax(n, axis=empty_axis))
         assert v == 5.0
         assert i == 0
         assert i.dtype == "int64"
@@ -778,6 +778,29 @@ def setup_method(self):
         v = eval_outputs(max_and_argmax(n)[1].shape)
         assert len(v) == 0
 
+    def test_empty_axis_tensor(self):
+        x = np.random.normal(size=(2, 3, 5, 7))
+        axis = ()
+
+        non_axis = tuple(i for i in range(x.ndim) if i not in axis)
+        shape_axis = tuple(x.shape[dim] for dim in axis)
+        shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
+        x_transposed = x.transpose(*axis, *non_axis)
+
+        x_axis_raveled = x_transposed.reshape(
+            np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
+        )
+        max_x = max_and_argmax(x, axis=axis)[0].eval()
+        argmax_x = max_and_argmax(x, axis=axis)[1].eval()
+
+        raveled_max = x_axis_raveled[
+            argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
+        ]
+        indirect_max = raveled_max.reshape(shape_non_axis)
+
+        np.testing.assert_allclose(max_x, x.max(axis=axis))
+        np.testing.assert_allclose(indirect_max, x.max(axis=axis))
+
     def test_basic_1(self):
         n = as_tensor_variable([1, 2, 3, 2, -6])
         v, i = eval_outputs(max_and_argmax(n))
@@ -796,8 +819,6 @@ def test_basic_1(self):
             (None, None),
             ([0, 1], None),
             ([1, 0], None),
-            (NoneConst.clone(), None),
-            (constant(0), 0),
         ],
     )
     def test_basic_2(self, axis, np_axis):
@@ -826,8 +847,6 @@ def test_basic_2(self, axis, np_axis):
            (None, None),
             ([0, 1], None),
             ([1, 0], None),
-            (NoneConst.clone(), None),
-            (constant(0), 0),
         ],
     )
     def test_basic_2_float16(self, axis, np_axis):
@@ -986,7 +1005,7 @@ def check_grad_max(data, max_grad_data, axis=None):
             safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
 
         # Test grad with multiple axes
-        for i in [[0, 1], [0, 0]]:
+        for i in [[0, 1], [0, 2, 3]]:
             safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data])
             safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data])
@@ -1043,29 +1062,6 @@ def test_vectorize(self, core_axis, batch_axis):
         assert isinstance(new_node.op, Argmax)
         assert new_node.op.axis == batch_axis
 
-    def test_max_empty_axis(self):
-        x = np.random.normal(size=(2, 3, 5, 7))
-        axis = ()
-
-        non_axis = tuple(i for i in range(x.ndim) if i not in axis)
-        shape_axis = tuple(x.shape[dim] for dim in axis)
-        shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
-        x_transposed = x.transpose(*axis, *non_axis)
-
-        x_axis_raveled = x_transposed.reshape(
-            np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
-        )
-        max_x = max_and_argmax(x, axis=axis)[0].eval()
-        argmax_x = max_and_argmax(x, axis=axis)[1].eval()
-
-        raveled_max = x_axis_raveled[
-            argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
-        ]
-        indirect_max = raveled_max.reshape(shape_non_axis)
-
-        np.testing.assert_allclose(max_x, x.max(axis=axis))
-        np.testing.assert_allclose(indirect_max, x.max(axis=axis))
-
 
 class TestArgminArgmax:
     def setup_method(self):
diff --git a/tests/test_rop.py b/tests/test_rop.py
index d8fc78a51b..0b9fe41a1e 100644
--- a/tests/test_rop.py
+++ b/tests/test_rop.py
@@ -192,9 +192,7 @@ def check_rop_lop(self, y, out_shape):
 
 
 class TestRopLop(RopLopChecker):
     def test_max(self):
-        # If we call max directly, we will return an CAReduce object
-        # which doesn't have R_op implemented!
-        # self.check_mat_rop_lop(at_max(self.mx, axis=[0,1])[0], ())
+        # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ())
        self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],))
        self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],))
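
Below is a minimal usage sketch, not part of the patch, of the behavior expected once the diff is applied. It assumes the usual pytensor.tensor namespace exports (pt.matrix, pt.max_and_argmax) and that normalize_reduce_axis is importable from pytensor.tensor.utils as added above; the data and variable names are illustrative only.

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.utils import normalize_reduce_axis

    # Negative and scalar axes are canonicalized to a tuple of non-negative ints;
    # None stays None, matching the new CAReduce/Argmax make_node logic.
    assert normalize_reduce_axis(-1, ndim=3) == (2,)
    assert normalize_reduce_axis((0, -1), ndim=3) == (0, 2)
    assert normalize_reduce_axis(None, ndim=3) is None

    # max_and_argmax is now a thin wrapper around max(...) and argmax(...),
    # both of which handle keepdims through makeKeepDims/expand_dims.
    x = pt.matrix("x", dtype="float64")
    m, am = pt.max_and_argmax(x, axis=1, keepdims=True)
    data = np.arange(6).reshape(2, 3).astype("float64")
    print(m.eval({x: data}))   # [[2.] [5.]], shape (2, 1)
    print(am.eval({x: data}))  # [[2] [2]], shape (2, 1)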