Only require input_ndim and not input_broadcastable in DimShuffle
ricardoV94 committed Aug 22, 2024
1 parent a9ed164 commit 0a667f6
Showing 5 changed files with 73 additions and 106 deletions.
156 changes: 61 additions & 95 deletions pytensor/tensor/elemwise.py
@@ -1,5 +1,7 @@
from collections.abc import Sequence
from copy import copy
from textwrap import dedent
from typing import Literal

import numpy as np
from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
Parameters
----------
input_broadcastable
The expected broadcastable pattern of the input
input_ndim
    The expected number of dimensions of the input
new_order
A list representing the relationship between the input's
dimensions and the output's dimensions. Each element of the
list can either be an index or 'x'. Indices must be encoded
as python integers, not pytensor symbolic integers.
inplace : bool, optional
If True (default), the output will be a view of the input.
    Missing indices correspond to dropped dimensions.
Notes
-----
@@ -77,50 +78,45 @@ class DimShuffle(ExternalCOp):
.. code-block:: python
DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
DimShuffle(input_ndim=3, new_order=['x', 2, 'x', 0, 1])
This `Op` will only work on 3d tensors.
The first dimension of the output will be broadcastable,
then we will have the third dimension of the input tensor as
the second of the resulting tensor, etc. If the tensor has
shape (20, 30, 40), the resulting tensor will have dimensions
(1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)
.. code-block:: python
DimShuffle((True, False), [1])
DimShuffle(input_ndim=2, new_order=[1])
This `Op` will only work on 2d tensors with the first dimension broadcastable.
The second dimension of the input tensor will be the first dimension of the resulting tensor.
If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
Examples
--------
.. code-block:: python
DimShuffle((), ['x']) # make a 0d (scalar) into a 1d vector
DimShuffle((False, False), [0, 1]) # identity
DimShuffle((False, False), [1, 0]) # inverts the 1st and 2nd dimensions
DimShuffle((False,), ['x', 0]) # make a row out of a 1d vector
# (N to 1xN)
DimShuffle((False,), [0, 'x']) # make a column out of a 1d vector
# (N to Nx1)
DimShuffle((False, False, False), [2, 0, 1]) # AxBxC to CxAxB
DimShuffle((False, False), [0, 'x', 1]) # AxB to Ax1xB
DimShuffle((False, False), [1, 'x', 0]) # AxB to Bx1xA
DimShuffle(input_ndim=0, new_order=['x']) # make a 0d (scalar) into a 1d vector
DimShuffle(input_ndim=2, new_order=[0, 1]) # identity
DimShuffle(input_ndim=2, new_order=[1, 0]) # transposition
DimShuffle(input_ndim=1, new_order=['x', 0]) # make a row out of a 1d vector (N to 1xN)
DimShuffle(input_ndim=1, new_order=[0, 'x']) # make a column out of a 1d vector (N to Nx1)
DimShuffle(input_ndim=3, new_order=[2, 0, 1]) # AxBxC to CxAxB
DimShuffle(input_ndim=2, new_order=[0, 'x', 1]) # AxB to Ax1xB
DimShuffle(input_ndim=2, new_order=[1, 'x', 0]) # AxB to Bx1xA
Notes
-----
The Python implementation of this `Op` combines `numpy.transpose` for reordering the dimensions
and `numpy.reshape` for dropping and adding broadcastable dimensions.
"""

_f16_ok = True
check_input = False
__props__ = ("input_broadcastable", "new_order", "inplace")
__props__ = ("input_ndim", "new_order", "inplace")
c_func_file = "c_code/dimshuffle.c"
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
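For illustration (a minimal sketch, not part of the commit; names and shapes are assumed), the new keyword-only constructor side by side with the `dimshuffle` helper, following the docstring example above:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle

x = pt.tensor3("x")
y = DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])(x)
z = x.dimshuffle("x", 2, "x", 0, 1)  # equivalent helper on the variable

# AxBxC -> 1xCx1xAxB, as in the docstring
assert y.eval({x: np.zeros((20, 30, 40), dtype=x.dtype)}).shape == (1, 40, 1, 20, 30)
```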

@@ -133,16 +129,11 @@ def params_type(self):
inplace=scalar_bool,
)

def __init__(self, input_broadcastable, new_order):
def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
super().__init__([self.c_func_file], self.c_func_name)

self.input_broadcastable = tuple(input_broadcastable)
if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
raise ValueError(
f"input_broadcastable must be boolean, {self.input_broadcastable}"
)
self.input_ndim = input_ndim
self.new_order = tuple(new_order)

self.inplace = True

for i, j in enumerate(new_order):
@@ -152,10 +143,10 @@ def __init__(self, input_broadcastable, new_order):
"DimShuffle indices must be Python ints; got "
f"{j} of type {type(j)}."
)
if j >= len(input_broadcastable):
if j >= input_ndim:
raise ValueError(
f"new_order[{i}] is {j}, but the input only has "
f"{len(input_broadcastable)} axes."
f"{input_ndim} axes."
)
if j in new_order[(i + 1) :]:
raise ValueError(
@@ -164,19 +155,7 @@ def __init__(self, input_broadcastable, new_order):
)

# List of input dimensions to drop
drop = []
for i, b in enumerate(input_broadcastable):
if i not in new_order:
# We want to drop this dimension because it's not a value in
# `new_order`
if b == 1:
drop.append(i)
else:
# We cannot drop non-broadcastable dimensions
raise ValueError(
"Cannot drop a non-broadcastable dimension: "
f"{input_broadcastable}, {new_order}"
)
drop = [i for i in range(input_ndim) if i not in new_order]

# This is the list of the original dimensions that we keep
self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +165,6 @@ def __init__(self, input_broadcastable, new_order):
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop

input_ndim = len(input_broadcastable)
self.is_left_expand_dims = self.augment and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)
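For illustration, a plain-Python sketch (not from the commit; the pattern is assumed) of how `shuffle`, `augment`, and `drop` fall out of a `new_order`:

```python
new_order, input_ndim = ["x", 2, 0], 3

shuffle = [x for x in new_order if x != "x"]                      # kept input axes: [2, 0]
augment = sorted(i for i, x in enumerate(new_order) if x == "x")  # output positions of "x": [0]
drop = [i for i in range(input_ndim) if i not in new_order]       # dropped input axes: [1]

assert (shuffle, augment, drop) == ([2, 0], [0], [1])
```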
@@ -204,30 +182,29 @@ def __setstate__(self, state):
# Let's just build the ExternalCOp.
super().__init__([self.c_func_file], self.c_func_name)

def make_node(self, inp):
    input = as_tensor_variable(inp)
    if input.type.ndim != self.input_ndim:
        raise TypeError(
            "The number of dimensions of the input is incorrect for this op. "
            f"Expected {self.input_ndim}, got {input.type.ndim}."
        )

    input_static_shape = input.type.shape

    # Runtime check for invalid drop
    for d in self.drop:
        if input_static_shape[d] not in (1, None):
            raise TypeError(
                f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
            )

    out_static_shape = []
    for dim_idx in self.new_order:
        if dim_idx == "x":
            out_static_shape.append(1)
        else:
            out_static_shape.append(input_static_shape[dim_idx])

(The removed implementation instead computed `ib = tuple(s == 1 for s in input.type.shape)` and raised a `TypeError` whenever `ib` differed from `self.input_broadcastable` in length or in broadcastable pattern.)

output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
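A sketch of the relaxed behavior this enables (assuming default config and a plain `pt.matrix` input): an axis of statically unknown length can now be dropped at graph-construction time, with the length-1 requirement deferred to the runtime check above:

```python
import pytensor.tensor as pt

x = pt.matrix("x")   # static shape (None, None)
y = x.dimshuffle(1)  # drops axis 0; the old input_broadcastable API raised here
assert y.type.shape == (None,)
# An input whose axis 0 does not have length 1 now fails when the graph is evaluated.
```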

@@ -254,12 +231,14 @@ def perform(self, node, inp, out):
if not isinstance(res, np.ndarray | np.memmap):
raise TypeError(res)

# Put dropped axes at the end
res = res.transpose(self.transposition)

shape = list(res.shape[: len(self.shuffle)])
# Define the new shape without the dropped axes and with the new ones
new_shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
res = res.reshape(shape)
new_shape.insert(augm, 1)
res = res.reshape(new_shape)

if not self.inplace:
res = np.copy(res)
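A pure-NumPy sketch of the transpose-plus-reshape recipe above, using the docstring's `(1, 20)` example with `new_order=[1, "x"]` (values assumed for illustration):

```python
import numpy as np

x = np.zeros((1, 20))
shuffle, drop, augment = [1], [0], [1]       # as precomputed in __init__
transposition = shuffle + drop               # dropped axes are moved to the end

res = x.transpose(transposition)             # (20, 1)
new_shape = list(res.shape[: len(shuffle)])  # keep only the shuffled axes -> [20]
for a in augment:
    new_shape.insert(a, 1)                   # re-insert the "x" axes -> [20, 1]
res = res.reshape(new_shape)

assert res.shape == (20, 1)
```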
@@ -284,22 +263,15 @@ def R_op(self, inputs, eval_points):
def grad(self, inp, grads):
    (x,) = inp
    (gz,) = grads
    grad_order = ["x"] * x.type.ndim
    for i, v in enumerate(self.new_order):
        if v != "x":
            grad_order[v] = i

    if x.type.dtype in discrete_dtypes:
        return [x.zeros_like(dtype=config.floatX)]
    else:
        return [gz.dimshuffle(grad_order)]

(The removed branch re-wrapped `gz` with `as_tensor_variable`, passed it through `Elemwise(scalar_identity)`, and constructed `DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)` directly, behind a comment about inplace canonicalization.)
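A sketch of the round trip this grad rule gives (illustrative shapes; not from the commit): the inverse pattern maps the gradient back to the input's shape:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
cost = x.dimshuffle(1, "x", 0).sum()
g = pytensor.grad(cost, x)

assert g.eval({x: np.ones((2, 3), dtype=x.dtype)}).shape == (2, 3)
```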


class DimShufflePrinter(Printer):
@@ -409,7 +381,7 @@ def __setstate__(self, d):
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)

def get_output_info(self, dim_shuffle, *inputs):
def get_output_info(self, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
dimshuffled inputs.
@@ -427,12 +399,7 @@ def get_output_info(self, dim_shuffle, *inputs):
if not difference:
args.append(input)
else:
args.append(
dim_shuffle(
input.type.broadcastable,
["x"] * difference + list(range(length)),
)(input)
)
args.append(input.dimshuffle(["x"] * difference + list(range(length))))
inputs = args

# HERE: all the broadcast dims have the same length now
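A sketch of the alignment this block performs (illustrative variables): shorter inputs are left-padded with `"x"` axes so all inputs reach a common ndim before the `Elemwise` node is built:

```python
import pytensor.tensor as pt

v = pt.vector("v")   # ndim 1
t = pt.tensor3("t")  # ndim 3

aligned = v.dimshuffle(["x"] * (t.ndim - v.ndim) + list(range(v.ndim)))
assert aligned.type.ndim == 3 and aligned.type.shape[:2] == (1, 1)
```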
@@ -489,7 +456,7 @@ def make_node(self, *inputs):
using DimShuffle.
"""
inputs = [as_tensor_variable(i) for i in inputs]
out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
outputs = [
TensorType(dtype=dtype, shape=shape)()
for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +601,7 @@ def transform(r):
res = pytensor.tensor.basic.constant(
np.asarray(r.data), dtype=r.type.dtype
)
return DimShuffle((), ["x"] * nd)(res)
return res.dimshuffle(["x"] * nd)

new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
if isinstance(new_r, list | tuple):
@@ -1707,13 +1674,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
if not batched_ndims:
return node.op.make_node(x)
input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
# e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
# e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
# e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
# e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
new_order = list(range(batched_ndims)) + [
"x" if (o == "x") else (o + batched_ndims) for o in op.new_order
]
return DimShuffle(input_broadcastable, new_order).make_node(x)
return x.dimshuffle(new_order).owner
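The index arithmetic from the example comments above, written out as a plain-Python sketch (values taken from the first comment):

```python
batched_ndims = 2
old_order = (1, "x", 0)

new_order = list(range(batched_ndims)) + [
    "x" if o == "x" else o + batched_ndims for o in old_order
]
assert new_order == [0, 1, 3, "x", 2]
```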


def get_normalized_batch_axes(
6 changes: 0 additions & 6 deletions pytensor/tensor/extra_ops.py
@@ -41,7 +41,6 @@
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
@@ -609,11 +608,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed
return _x

# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
_x = specify_broadcastable(_x, *non_broadcastable_axis)

return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
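A sketch of the consequence for `squeeze` (assuming a `matrix` input with unknown static shape): the axis no longer needs to be statically broadcastable to build the graph, and an invalid drop surfaces when the graph is evaluated instead:

```python
import numpy as np
from pytensor.tensor import matrix
from pytensor.tensor.extra_ops import squeeze

x = matrix("x")         # static shape (None, None)
y = squeeze(x, axis=0)  # previously needed specify_broadcastable
assert y.eval({x: np.ones((1, 5), dtype=x.dtype)}).shape == (5,)
```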


4 changes: 1 addition & 3 deletions pytensor/tensor/math.py
@@ -33,7 +33,6 @@
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import (
CAReduce,
DimShuffle,
Elemwise,
get_normalized_batch_axes,
scalar_elemwise,
@@ -2338,8 +2337,7 @@ def L_op(self, inp, out, grads):
else:
new_dims.append(i)
i += 1
ds_op = DimShuffle(gz.type.broadcastable, new_dims)
gx = Elemwise(ps.second)(x, ds_op(gz))
gx = Elemwise(ps.second)(x, gz.dimshuffle(new_dims))
return [gx]
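A sketch of the re-expansion pattern used here (illustrative types; not from the commit): the axes reduced out of `gz` are re-inserted as `"x"` so it broadcasts back against `x`:

```python
import pytensor.tensor as pt

x = pt.matrix("x")
gz = pt.vector("gz")                 # e.g. the gradient of x.max(axis=1)

gz_expanded = gz.dimshuffle(0, "x")  # what new_dims builds for axis=1
assert gz_expanded.type.ndim == x.type.ndim
```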

def R_op(self, inputs, eval_points):
4 changes: 2 additions & 2 deletions pytensor/tensor/variable.py
@@ -344,8 +344,8 @@ def dimshuffle(self, *pattern):
"""
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
pattern = pattern[0]
op = pt.elemwise.DimShuffle(list(self.type.broadcastable), pattern)
return op(self)
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
return ds_op(self)
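A sketch showing that both calling conventions accepted here reach the same keyword-only `DimShuffle` (illustrative pattern):

```python
import pytensor.tensor as pt

x = pt.matrix("x")
a = x.dimshuffle(1, "x", 0)    # varargs pattern
b = x.dimshuffle([1, "x", 0])  # single list/tuple pattern
assert a.type == b.type        # both build DimShuffle(input_ndim=2, new_order=(1, "x", 0))
```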

def flatten(self, ndim=1):
return pt.basic.flatten(self, ndim)
9 changes: 9 additions & 0 deletions tests/tensor/test_fft.py
@@ -204,3 +204,12 @@ def f_irfft(inp):
pytensor.config.floatX
)
utt.verify_grad(f_irfft, [inputs_val], eps=eps)

def test_rfft_expanded_dims_grad(self):
# Regression test for https://github.com/pymc-devs/pytensor/issues/969
def test_func(x):
return fft.rfft(x[None])

rng = np.random.default_rng(213)
inputs_val = rng.random((N,)).astype(pytensor.config.floatX)
utt.verify_grad(test_func, [inputs_val], eps=1e-2)
