diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index f48b3b9a59..bde6b2c6a4 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -1,5 +1,7 @@
+from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
+from typing import Literal
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
 
     Parameters
     ----------
-    input_broadcastable
-        The expected broadcastable pattern of the input
+    input_ndim
+        The expected number of dimensions of the input
     new_order
         A list representing the relationship between the input's
        dimensions and the output's dimensions. Each element of the
         list can either be an index or 'x'. Indices must be encoded
         as python integers, not pytensor symbolic integers.
-    inplace : bool, optional
-        If True (default), the output will be a view of the input.
+        Missing indices correspond to dropped dimensions.
 
     Notes
     -----
@@ -77,10 +78,10 @@ class DimShuffle(ExternalCOp):
 
     .. code-block:: python
 
-        DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
+        DimShuffle(input_ndim=3, new_order=['x', 2, 'x', 0, 1])
 
-    This `Op` will only work on 3d tensors with no broadcastable
-    dimensions. The first dimension will be broadcastable,
+    This `Op` will only work on 3d tensors.
+    The first dimension of the output will be broadcastable,
     then we will have the third dimension of the input tensor as the
     second of the resulting tensor, etc. If the tensor has shape
     (20, 30, 40), the resulting tensor will have dimensions
@@ -88,39 +89,34 @@ class DimShuffle(ExternalCOp):
 
     .. code-block:: python
 
-        DimShuffle((True, False), [1])
+        DimShuffle(input_ndim=2, new_order=[1])
 
-    This `Op` will only work on 2d tensors with the first dimension
-    broadcastable.
-    The second dimension of the input tensor will be the first dimension of
-    the resulting tensor.
-    If the tensor has shape (1, 20), the resulting tensor will have shape
-    (20, ).
+    This `Op` will only work on 2d tensors with the first dimension broadcastable.
+    The second dimension of the input tensor will be the first dimension of the resulting tensor.
+    If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
 
     Examples
     --------
 
     .. code-block:: python
 
-        DimShuffle((), ['x'])  # make a 0d (scalar) into a 1d vector
-        DimShuffle((False, False), [0, 1])  # identity
-        DimShuffle((False, False), [1, 0])  # inverts the 1st and 2nd dimensions
-        DimShuffle((False,), ['x', 0])  # make a row out of a 1d vector
-        # (N to 1xN)
-        DimShuffle((False,), [0, 'x'])  # make a column out of a 1d vector
-        # (N to Nx1)
-        DimShuffle((False, False, False), [2, 0, 1])  # AxBxC to CxAxB
-        DimShuffle((False, False), [0, 'x', 1])  # AxB to Ax1xB
-        DimShuffle((False, False), [1, 'x', 0])  # AxB to Bx1xA
-
-    The reordering of the dimensions can be done with the numpy.transpose
-    function.
-    Adding, subtracting dimensions can be done with reshape.
+        DimShuffle(input_ndim=0, new_order=['x'])  # make a 0d (scalar) into a 1d vector
+        DimShuffle(input_ndim=2, new_order=[0, 1])  # identity
+        DimShuffle(input_ndim=2, new_order=[1, 0])  # transposition
+        DimShuffle(input_ndim=1, new_order=['x', 0])  # make a row out of a 1d vector (N to 1xN)
+        DimShuffle(input_ndim=1, new_order=[0, 'x'])  # make a column out of a 1d vector (N to Nx1)
+        DimShuffle(input_ndim=3, new_order=[2, 0, 1])  # AxBxC to CxAxB
+        DimShuffle(input_ndim=2, new_order=[0, 'x', 1])  # AxB to Ax1xB
+        DimShuffle(input_ndim=2, new_order=[1, 'x', 0])  # AxB to Bx1xA
+    Notes
+    -----
+    The python implementation of this Op combines numpy.transpose for reordering of the dimensions
+    and numpy.reshape for subtracting and adding broadcastable dimensions.
     """
 
     _f16_ok = True
     check_input = False
-    __props__ = ("input_broadcastable", "new_order", "inplace")
+    __props__ = ("input_ndim", "new_order", "inplace")
 
     c_func_file = "c_code/dimshuffle.c"
     c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
@@ -133,16 +129,11 @@ def params_type(self):
             inplace=scalar_bool,
         )
 
-    def __init__(self, input_broadcastable, new_order):
+    def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
         super().__init__([self.c_func_file], self.c_func_name)
 
-        self.input_broadcastable = tuple(input_broadcastable)
-        if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
-            raise ValueError(
-                f"input_broadcastable must be boolean, {self.input_broadcastable}"
-            )
+        self.input_ndim = input_ndim
         self.new_order = tuple(new_order)
-
         self.inplace = True
 
         for i, j in enumerate(new_order):
@@ -152,10 +143,10 @@ def __init__(self, input_broadcastable, new_order):
                     "DimShuffle indices must be Python ints; got "
                     f"{j} of type {type(j)}."
                 )
-            if j >= len(input_broadcastable):
+            if j >= input_ndim:
                 raise ValueError(
                     f"new_order[{i}] is {j}, but the input only has "
-                    f"{len(input_broadcastable)} axes."
+                    f"{input_ndim} axes."
                 )
             if j in new_order[(i + 1) :]:
                 raise ValueError(
@@ -164,19 +155,7 @@ def __init__(self, input_broadcastable, new_order):
                 )
 
         # List of input dimensions to drop
-        drop = []
-        for i, b in enumerate(input_broadcastable):
-            if i not in new_order:
-                # We want to drop this dimension because it's not a value in
-                # `new_order`
-                if b == 1:
-                    drop.append(i)
-                else:
-                    # We cannot drop non-broadcastable dimensions
-                    raise ValueError(
-                        "Cannot drop a non-broadcastable dimension: "
-                        f"{input_broadcastable}, {new_order}"
-                    )
+        drop = [i for i in range(input_ndim) if i not in new_order]
 
         # This is the list of the original dimensions that we keep
         self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +165,6 @@ def __init__(self, input_broadcastable, new_order):
         self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
         self.drop = drop
 
-        input_ndim = len(input_broadcastable)
         self.is_left_expand_dims = self.augment and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
@@ -204,30 +182,29 @@ def __setstate__(self, state):
             # Let's just build the ExternalCOp.
             super().__init__([self.c_func_file], self.c_func_name)
 
-    def make_node(self, _input):
-        input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
+    def make_node(self, inp):
+        input = as_tensor_variable(inp)
+        if input.type.ndim != self.input_ndim:
+            raise TypeError(
+                "The number of dimensions of the input is incorrect for this op. "
+                f"Expected {self.input_ndim}, got {input.type.ndim}."
+            )
+
+        input_static_shape = input.type.shape
+
+        # Runtime check for invalid drop
+        for d in self.drop:
+            if input_static_shape[d] not in (1, None):
                 raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+                    f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
                 )
-            for expected, b in zip(self.input_broadcastable, ib):
-                if expected and not b:
-                    raise TypeError(
-                        "The broadcastable pattern of the "
-                        f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                    )
-                # else, expected == b or not expected and b
-                # Both case are good.
 
         out_static_shape = []
         for dim_idx in self.new_order:
             if dim_idx == "x":
                 out_static_shape.append(1)
             else:
-                out_static_shape.append(input.type.shape[dim_idx])
+                out_static_shape.append(input_static_shape[dim_idx])
 
         output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
 
@@ -254,12 +231,14 @@ def perform(self, node, inp, out):
         if not isinstance(res, np.ndarray | np.memmap):
             raise TypeError(res)
 
+        # Put dropped axes at the end
         res = res.transpose(self.transposition)
 
-        shape = list(res.shape[: len(self.shuffle)])
+        # Define new shape without dropped axes and including new ones
+        new_shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
-            shape.insert(augm, 1)
-        res = res.reshape(shape)
+            new_shape.insert(augm, 1)
+        res = res.reshape(new_shape)
 
         if not self.inplace:
             res = np.copy(res)
@@ -284,22 +263,15 @@ def R_op(self, inputs, eval_points):
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
-        gz = as_tensor_variable(gz)
         grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
-        # Do not make the DimShuffle inplace as an optimization at the
-        # canonicalization optimization phase will remove the inplace.
-        # The inplace will be reintroduced automatically later in the graph.
-        if inp[0].dtype in discrete_dtypes:
-            return [inp[0].zeros_like(dtype=config.floatX)]
+
+        if x.type.dtype in discrete_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
         else:
-            return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
-                    Elemwise(scalar_identity)(gz)
-                )
-            ]
+            return [gz.dimshuffle(grad_order)]
 
 
 class DimShufflePrinter(Printer):
@@ -409,7 +381,7 @@ def __setstate__(self, d):
         self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)
 
-    def get_output_info(self, dim_shuffle, *inputs):
+    def get_output_info(self, *inputs):
         """Return the outputs dtype and broadcastable pattern and the
         dimshuffled inputs.
 
@@ -427,12 +399,7 @@ def get_output_info(self, dim_shuffle, *inputs):
             if not difference:
                 args.append(input)
             else:
-                args.append(
-                    dim_shuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(length)),
-                    )(input)
-                )
+                args.append(input.dimshuffle(["x"] * difference + list(range(length))))
         inputs = args
 
         # HERE: all the broadcast dims have the same length now
@@ -489,7 +456,7 @@ def make_node(self, *inputs):
         using DimShuffle.
""" inputs = [as_tensor_variable(i) for i in inputs] - out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs) + out_dtypes, out_shapes, inputs = self.get_output_info(*inputs) outputs = [ TensorType(dtype=dtype, shape=shape)() for dtype, shape in zip(out_dtypes, out_shapes) @@ -634,7 +601,7 @@ def transform(r): res = pytensor.tensor.basic.constant( np.asarray(r.data), dtype=r.type.dtype ) - return DimShuffle((), ["x"] * nd)(res) + return res.dimshuffle(["x"] * nd) new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs]) if isinstance(new_r, list | tuple): @@ -1707,13 +1674,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl batched_ndims = x.type.ndim - node.inputs[0].type.ndim if not batched_ndims: return node.op.make_node(x) - input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable - # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2)) - # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x")) + # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2)) + # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x")) new_order = list(range(batched_ndims)) + [ "x" if (o == "x") else (o + batched_ndims) for o in op.new_order ] - return DimShuffle(input_broadcastable, new_order).make_node(x) + return x.dimshuffle(new_order).owner def get_normalized_batch_axes( diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index cf809a55ef..91c6eb48f0 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -41,7 +41,6 @@ ) from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import sum as pt_sum -from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.variable import TensorVariable @@ -609,11 +608,6 @@ def squeeze(x, axis=None): # Nothing could be squeezed return _x - # `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable. - # We add a `specify_broadcastable` instead of raising. 
-    non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
-    _x = specify_broadcastable(_x, *non_broadcastable_axis)
-
     return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
 
 
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 86df161fb6..a47d2997cd 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -33,7 +33,6 @@
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
     CAReduce,
-    DimShuffle,
     Elemwise,
     get_normalized_batch_axes,
     scalar_elemwise,
@@ -2338,8 +2337,7 @@ def L_op(self, inp, out, grads):
             else:
                 new_dims.append(i)
                 i += 1
-        ds_op = DimShuffle(gz.type.broadcastable, new_dims)
-        gx = Elemwise(ps.second)(x, ds_op(gz))
+        gx = Elemwise(ps.second)(x, gz.dimshuffle(new_dims))
         return [gx]
 
     def R_op(self, inputs, eval_points):
diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index 613fb80f3e..261a8bbc4a 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -344,8 +344,8 @@ def dimshuffle(self, *pattern):
         """
         if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
             pattern = pattern[0]
-        op = pt.elemwise.DimShuffle(list(self.type.broadcastable), pattern)
-        return op(self)
+        ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
+        return ds_op(self)
 
     def flatten(self, ndim=1):
         return pt.basic.flatten(self, ndim)
diff --git a/tests/tensor/test_fft.py b/tests/tensor/test_fft.py
index 3599c97de3..d070a8da42 100644
--- a/tests/tensor/test_fft.py
+++ b/tests/tensor/test_fft.py
@@ -204,3 +204,12 @@ def f_irfft(inp):
             pytensor.config.floatX
         )
         utt.verify_grad(f_irfft, [inputs_val], eps=eps)
+
+    def test_rfft_expanded_dims_grad(self):
+        # Regression test for https://github.com/pymc-devs/pytensor/issues/969
+        def test_func(x):
+            return fft.rfft(x[None])
+
+        rng = np.random.default_rng(213)
+        inputs_val = rng.random((N,)).astype(pytensor.config.floatX)
+        utt.verify_grad(test_func, [inputs_val], eps=1e-2)
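
For reference, a minimal usage sketch of the migrated API follows; the tensor name, shape, and shuffle pattern are illustrative only and not taken from the diff.

# Illustrative sketch: the keyword-only DimShuffle constructor and the
# .dimshuffle() variable method that this diff migrates internal callers to.
import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle

x = pt.matrix("x")  # 2d symbolic tensor, shape (A, B)

# Direct Op construction now takes input_ndim instead of input_broadcastable
y = DimShuffle(input_ndim=2, new_order=[1, "x", 0])(x)  # AxB -> Bx1xA

# Equivalent, and what the diff uses throughout: the variable-level helper
z = x.dimshuffle(1, "x", 0)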