diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py index f4098416b8..00976f221c 100644 --- a/pytensor/link/jax/dispatch/__init__.py +++ b/pytensor/link/jax/dispatch/__init__.py @@ -4,6 +4,7 @@ # Load dispatch specializations import pytensor.link.jax.dispatch.blas import pytensor.link.jax.dispatch.blockwise +import pytensor.link.jax.dispatch.einsum import pytensor.link.jax.dispatch.elemwise import pytensor.link.jax.dispatch.extra_ops import pytensor.link.jax.dispatch.pad diff --git a/pytensor/link/jax/dispatch/einsum.py b/pytensor/link/jax/dispatch/einsum.py new file mode 100644 index 0000000000..3080f6964f --- /dev/null +++ b/pytensor/link/jax/dispatch/einsum.py @@ -0,0 +1,20 @@ +import jax.numpy as jnp + +from pytensor.link.jax.dispatch import jax_funcify +from pytensor.tensor.einsum import Einsum + + +@jax_funcify.register(Einsum) +def jax_funcify_Einsum(op, **kwargs): + """Dispatch einsum to JAX. + + This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level. + This happens when some of the dimension lengths are unknown. This is never a problem in JAX, + as it always compiles a function per runtime input shape. + """ + subscripts = op.subscripts + + def einsum(*operands): + return jnp.einsum(subscripts, *operands, optimize="optimal") + + return einsum diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 81cabfa6bd..7385f02478 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -151,6 +151,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # isort: off +from pytensor.tensor.einsum import einsum from pytensor.tensor.functional import vectorize # isort: on diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 119c44c647..9eaa04c522 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1700,21 +1700,22 @@ def do_constant_folding(self, fgraph, node): return False for client, idx in clients: - if isinstance(client.op, Output): + client_op = client.op + if isinstance(client_op, Output): # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False - # Allow alloc to be lifted out of Elemwise before constant folding it - elif isinstance(client.op, Elemwise): - return None + # Op's through which Alloc can be lifted + elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join): + return False # Same for Blockwise, unless it has no batch_dims - elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client): - return None + elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client): + return False elif ( # The following ops work inplace of their input id 0. 
idx == 0 and isinstance( - client.op, + client_op, pytensor.tensor.subtensor.IncSubtensor | pytensor.tensor.subtensor.AdvancedIncSubtensor1 | pytensor.tensor.subtensor.AdvancedIncSubtensor @@ -2035,10 +2036,15 @@ def transpose(x, axes=None): _x = as_tensor_variable(x) if axes is None: - axes = list(range((_x.type.ndim - 1), -1, -1)) + axes = tuple(range((_x.type.ndim - 1), -1, -1)) + + if tuple(axes) == tuple(range(len(axes))): + # No-op + return _x + ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x) - if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)): + if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)): ret.name = _x.name + ".T" return ret @@ -3950,6 +3956,10 @@ def moveaxis( source = normalize_axis_tuple(source, a.ndim, "source") destination = normalize_axis_tuple(destination, a.ndim, "destination") + if source == destination: + # It's a no-op + return a + if len(source) != len(destination): raise ValueError( "`source` and `destination` arguments must have the same number of elements" @@ -4260,9 +4270,7 @@ def atleast_Nd( atleast_3d = partial(atleast_Nd, n=3) -def expand_dims( - a: np.ndarray | TensorVariable, axis: tuple[int, ...] -) -> TensorVariable: +def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable: """Expand the shape of an array. Insert a new axis that will appear at the `axis` position in the expanded @@ -4281,7 +4289,7 @@ def expand_dims( """ a = as_tensor(a) - if not isinstance(axis, tuple | list): + if not isinstance(axis, Sequence): axis = (axis,) out_ndim = len(axis) + a.ndim diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py new file mode 100644 index 0000000000..278318de33 --- /dev/null +++ b/pytensor/tensor/einsum.py @@ -0,0 +1,585 @@ +import collections +import warnings +from collections.abc import Sequence +from functools import partial, reduce +from itertools import pairwise +from typing import cast + +import numpy as np +from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore +from numpy.core.numeric import ( # type: ignore + normalize_axis_index, + normalize_axis_tuple, +) + +from pytensor.compile.builders import OpFromGraph +from pytensor.tensor import TensorLike +from pytensor.tensor.basic import ( + arange, + as_tensor, + get_vector_length, + moveaxis, + stack, + transpose, + where, +) +from pytensor.tensor.extra_ops import broadcast_to +from pytensor.tensor.functional import vectorize +from pytensor.tensor.math import and_, eq, tensordot +from pytensor.tensor.shape import shape_padright +from pytensor.tensor.variable import TensorVariable + + +PATH = tuple[tuple[int] | tuple[int, int], ...] + + +class Einsum(OpFromGraph): + """ + Wrapper Op for Einsum graphs + + Notes + ----- + The `optimized` prop indicates whether the inner graph was optimized, which can only be done when all shapes are + statically known. This is now determined at graph creation time only. We could introduce a rewrite that tries to + optimize the graph if static shapes become known later (e.g., after use of `clone_replace` or shape inference during + rewrites). + + Also, once the graph is optimized, it could be inlined for potential further optimization that consider the rest of + the graph. + + This prop is different from the `optimize` kwarg in numpy that determines what kind (if any) of optimization is + desired. We haven't decided whether we want to provide this functionality. 
+ """ + + __props__ = ("subscripts", "path", "optimized") + + def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs): + self.subscripts = subscripts + self.path = path + self.optimized = optimized + super().__init__(*args, **kwargs, strict=True) + + +def _iota(shape: TensorVariable, axis: int) -> TensorVariable: + len_shape = get_vector_length(shape) + axis = normalize_axis_index(axis, len_shape) + values = arange(shape[axis]) + return broadcast_to(shape_padright(values, len_shape - axis - 1), shape) + + +def _delta(shape, axes: Sequence[int]) -> TensorVariable: + """This utility function exists for creating Kronecker delta arrays.""" + base_shape = stack([shape[axis] for axis in axes]) + iotas = [_iota(base_shape, i) for i in range(len(axes))] + eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)] + result = reduce(and_, eyes) + return broadcast_to(result, shape) + + +def _general_dot( + vars: tuple[TensorVariable, TensorVariable], + axes: Sequence[Sequence[int]], # Should be length 2, + batch_axes: Sequence[Sequence[int]], # Should be length 2, +) -> TensorVariable: + # Shortcut for non batched case + if not batch_axes[0] and not batch_axes[1]: + return tensordot(*vars, axes=axes) + + # Normalize axes, thankfully numpy helper does not sort axis! + axes = [ + normalize_axis_tuple(var_axes, var.ndim) + for var, var_axes in zip(vars, axes, strict=True) + ] + batch_axes = [ + normalize_axis_tuple(var_axes, var.ndim) + for var, var_axes in zip(vars, batch_axes, strict=True) + ] + n_batch_axes = [len(var_batch_axes) for var_batch_axes in batch_axes] + + # Move batch axes to the left and recode reduction axes + new_vars = list(vars) + new_axes = list(axes) + for i, (var, var_axes, var_batch_axes, var_n_batch_axes) in enumerate( + zip(vars, axes, batch_axes, n_batch_axes, strict=True) + ): + if var_batch_axes == tuple(range(var_n_batch_axes)): + # Already on left to right order + continue + + new_var_batch_axes = tuple(range(var_n_batch_axes)) + new_var = moveaxis(var, var_batch_axes, new_var_batch_axes) + + new_var_axes = [] + for var_axis in var_axes: + batch_axes_to_the_right = len( + [batch_axis for batch_axis in var_batch_axes if batch_axis > var_axis] + ) + new_var_axes.append(var_axis + batch_axes_to_the_right) + + new_vars[i] = new_var + new_axes[i] = new_var_axes + + lhs, rhs = new_vars + lhs_axes, rhs_axes = new_axes + lhs_n_batch_axes, rhs_n_batch_axes = n_batch_axes + + # Create signature of tensordot + lhs_signature = [f"l{i}" for i in range(lhs.type.ndim)] + rhs_signature = [f"r{i}" for i in range(rhs.type.ndim)] + # Aligned axes get the same dimension name + for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes)): + lhs_signature[lhs_axis] = rhs_signature[rhs_axis] = f"a{i}" + # Trim away the batch ndims + lhs_signature = lhs_signature[lhs_n_batch_axes:] + rhs_signature = rhs_signature[rhs_n_batch_axes:] + out_signature = [ + lhs_dim for lhs_dim in lhs_signature if not lhs_dim.startswith("a") + ] + [rhs_dim for rhs_dim in rhs_signature if not rhs_dim.startswith("a")] + signature = f"({','.join(lhs_signature)}),({','.join(rhs_signature)})->({','.join(out_signature)})" + # Adjust axes for core case + core_lhs_axes = tuple(np.array(lhs_axes) - lhs_n_batch_axes) + core_rhs_axes = tuple(np.array(rhs_axes) - rhs_n_batch_axes) + + if signature == "(),()->()": + # Just a multiplication + out = lhs * rhs + else: + out = vectorize( + partial(tensordot, axes=[core_lhs_axes, core_rhs_axes]), signature=signature + )(lhs, rhs) + + return cast(TensorVariable, 
out) + + +def _contraction_list_from_path( + subscripts: str, operands: Sequence[TensorVariable], path: PATH +): + """ + Generate a list of contraction steps based on the provided einsum path. + + Code adapted from opt_einsum: https://github.com/dgasmith/opt_einsum/blob/94c62a05d5ebcedd30f59c90b9926de967ed10b5/opt_einsum/contract.py#L369 + + When all shapes are known, the linked opt_einsum implementation is preferred. This implementation is used when + some or all shapes are not known. As a result, contraction will (always?) be done left-to-right, pushing intermediate + results to the end of the stack. + + Parameters + ---------- + subscripts: str + Einsum signature string describing the computation to be performed. + + operands: Sequence[TensorLike] + Tensors described by the subscripts. + + path: tuple[tuple[int] | tuple[int, int]] + A list of tuples, where each tuple describes the indices of the operands to be contracted, sorted in the order + they should be contracted. + + Returns + ------- + contraction_list: list + A list of tuples, where each tuple describes a contraction step. Each tuple contains the following elements: + - contraction_inds: tuple[int] + The indices of the operands to be contracted + - idx_removed: str + The indices of the contracted indices (those removed from the einsum string at this step) + - einsum_str: str + The einsum string for the contraction step + - remaining: None + The remaining indices. Included to match the output of opt_einsum.contract_path, but not used. + - do_blas: None + Whether to use blas to perform this step. Included to match the output of opt_einsum.contract_path, + but not used. + """ + fake_operands = [ + np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands + ] + input_subscripts, output_subscript, operands = _parse_einsum_input( + (subscripts, *fake_operands) + ) + + # Build a few useful lists and sets + input_list = input_subscripts.split(",") + input_sets = [set(x) for x in input_list] + output_set = set(output_subscript) + + # Build contraction tuple (positions, gemm, einsum_str, remaining) + contraction_list = [] + for cnum, contract_inds in enumerate(path): + # Make sure we remove inds from right to left + contract_inds = cast( + tuple[int] | tuple[int, int], tuple(sorted(contract_inds, reverse=True)) + ) + + contract_tuple = _find_contraction(contract_inds, input_sets, output_set) + out_inds, input_sets, idx_removed, idx_contract = contract_tuple + + tmp_inputs = [input_list.pop(x) for x in contract_inds] + + # Last contraction + if (cnum - len(path)) == -1: + idx_result = output_subscript + else: + # use tensordot order to minimize transpositions + all_input_inds = "".join(tmp_inputs) + idx_result = "".join(sorted(out_inds, key=all_input_inds.find)) + + input_list.append(idx_result) + einsum_str = ",".join(tmp_inputs) + "->" + idx_result + + # We only need the first three inputs to build the forward graph + contraction = (contract_inds, idx_removed, einsum_str, None, None) + contraction_list.append(contraction) + + return contraction_list + + +def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable: + """ + Multiplication and summation of tensors using the Einstein summation convention. + + Code adapted from JAX: https://github.com/google/jax/blob/534d32a24d7e1efdef206188bb11ae48e9097092/jax/_src/numpy/lax_numpy.py#L5283 + + Einsum allows the user to specify a wide range of operations on tensors using the Einstein summation convention.
Using + this notation, many common linear algebraic operations can be succinctly expressed on higher order tensors. + + Parameters + ---------- + subscripts: str + Einsum signature string describing the computation to be performed. + + operands: sequence of TensorVariable + Tensors to be multiplied and summed. + + Returns + ------- + TensorVariable + The result of the einsum operation. + + See Also + -------- + pytensor.tensor.tensordot: Generalized dot product between two tensors + pytensor.tensor.dot: Matrix multiplication between two tensors + numpy.einsum: The numpy implementation of einsum + + Examples + -------- + Inputs to `pt.einsum` are a string describing the operation to be performed (the "subscripts"), and a sequence of + tensors to be operated on. The string must follow these rules: + + 1. The string gives inputs and (optionally) outputs. Inputs and outputs are separated by "->". + 2. The input side of the string is a comma-separated list of indices. For each comma-separated index string, there + must be a corresponding tensor in the input sequence. + 3. For each index string, the number of dimensions in the corresponding tensor must match the number of characters + in the index string. + 4. Indices are arbitrary single characters. If an index appears multiple times on the input side, the corresponding + dimension must have the same length in each input. + 5. The indices on the output side must be a subset of the indices on the input side -- you cannot introduce new + indices in the output. + 6. Ellipses ("...") can be used to elide multiple indices. This is useful when you have a large number of "batch" + dimensions that are not implicated in the operation. + + Finally, two rules about these indices govern how computation is carried out: + + 1. Repeated indices on the input side indicate how the tensor should be "aligned" for multiplication. + 2. Indices that appear on the input side but not the output side are summed over. + + The operation of these rules is best understood via examples: + + Example 1: Matrix multiplication + + .. code-block:: python + + import pytensor.tensor as pt + A = pt.matrix("A") + B = pt.matrix("B") + C = pt.einsum("ij, jk -> ik", A, B) + + This computation is equivalent to :code:`C = A @ B`. Notice that the ``j`` index is repeated on the input side of the + signature, and does not appear on the output side. This indicates that the ``j`` dimension of the first tensor should be + multiplied with the ``j`` dimension of the second tensor, and the resulting tensor's ``j`` dimension should be summed + away. + + Example 2: Batched matrix multiplication + + .. code-block:: python + + import pytensor.tensor as pt + A = pt.tensor("A", shape=(None, 4, 5)) + B = pt.tensor("B", shape=(None, 5, 6)) + C = pt.einsum("bij, bjk -> bik", A, B) + + This computation is also equivalent to :code:`C = A @ B` because of PyTensor's built-in broadcasting rules, but + the einsum signature is more explicit about the batch dimensions. The ``b`` and ``j`` indices are repeated on the + input side. Unlike ``j``, the ``b`` index is also present on the output side, indicating that the batch dimension + should **not** be summed away. As a result, multiplication will be performed over the ``b, j`` dimensions, and then + the ``j`` dimension will be summed over. The resulting tensor will have shape ``(None, 4, 6)``. + + Example 3: Batched matrix multiplication with ellipses + +
.. code-block:: python + + import pytensor.tensor as pt + A = pt.tensor("A", shape=(4, None, None, None, 5)) + B = pt.tensor("B", shape=(5, None, None, None, 6)) + C = pt.einsum("i...j, j...k -> ...ik", A, B) + + This case is the same as above, but inputs ``A`` and ``B`` have multiple batch dimensions. To avoid writing out all + of the batch dimensions (which we do not care about), we can use ellipses to elide over these dimensions. Notice + also that we are not required to "sort" the input dimensions in any way. In this example, we are doing a dot + between the last dimension of A and the first dimension of B, which is perfectly valid. + + Example 4: Outer product + + .. code-block:: python + + import pytensor.tensor as pt + x = pt.tensor("x", shape=(3,)) + y = pt.tensor("y", shape=(4,)) + z = pt.einsum("i, j -> ij", x, y) + + This computation is equivalent to :code:`pt.outer(x, y)`. Notice that no indices are repeated on the input side, + and the output side has two indices. Since there are no indices to align on, the einsum operation will simply + multiply the two tensors elementwise, broadcasting dimensions ``i`` and ``j``. + + Example 5: Convolution + + .. code-block:: python + + import pytensor.tensor as pt + x = pt.tensor("x", shape=(None, None, None, None, None, None)) + w = pt.tensor("w", shape=(None, None, None, None)) + y = pt.einsum("bchwkt,fckt->bfhw", x, w) + + Given a batch of windowed images ``x`` with dimensions ``(batch, channel, height, width, kernel_size, kernel_size)`` + and a filter ``w`` with dimensions ``(num_filters, channels, kernel_size, kernel_size)``, this einsum operation + computes the convolution of ``x`` with ``w``. Multiplication is aligned on the channel and the two kernel + dimensions, which are then summed over. The resulting tensor has shape + ``(batch, num_filters, height, width)``, reflecting the fact that information from each channel has been mixed + together. + """ + + if optimize is not None: + raise NotImplementedError( + "Optimize kwarg is not implemented in PyTensor. " + "By default, PyTensor will always optimize the graph if the inputs have static shapes.\n" + "If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. " + ) + + # TODO: Is this doing something clever about unknown shapes? + # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler) + tensor_operands = [as_tensor(operand) for operand in operands] + shapes = [operand.type.shape for operand in tensor_operands] + + path: PATH + if any(None in shape for shape in shapes): + # Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize + # the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right, + # pushing intermediate results to the end of the stack. + # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will + # match more often + + # If shapes become known later we will likely want to rebuild the Op (unless we inline it) + if len(tensor_operands) == 1: + path = ((0,),) + else: + # By default, we try right to left because we assume that most graphs + # have a lower dimensional rightmost operand + path = tuple(pairwise(reversed(range(len(tensor_operands))))) + contraction_list = _contraction_list_from_path( + subscripts, tensor_operands, path + ) + + # If there are only 1 or 2 operands, there is no optimization to be done?
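+        # (With only one or two operands there is a single possible contraction order, +        # so this naive path is already as good as it gets.)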
+ optimized = len(tensor_operands) <= 2 + else: + # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal + # contraction order. + _, contraction_list = np.einsum_path( + subscripts, + # Numpy einsum_path requires arrays even though only the shapes matter + # It's not trivial to duck-type our way around because of internal call to `asanyarray` + *[np.empty(shape) for shape in shapes], + einsum_call=True, # Not part of public API + optimize="optimal", + ) # type: ignore + path = tuple(contraction[0] for contraction in contraction_list) + optimized = True + + def removechars(s, chars): + return s.translate(str.maketrans(dict.fromkeys(chars))) + + def sum_uniques( + operand: TensorVariable, names: str, uniques: list[str] + ) -> tuple[TensorVariable, str]: + """Reduce unique indices (those that appear only once) in a given contraction step via summing.""" + if uniques: + axes = [names.index(name) for name in uniques] + operand = operand.sum(axes) + names = removechars(names, uniques) + return operand, names + + def sum_repeats( + operand: TensorVariable, + names: str, + counts: collections.Counter, + keep_names: str, + ) -> tuple[TensorVariable, str]: + """Reduce repeated indices in a given contraction step via summation against an identity matrix.""" + + for name, count in counts.items(): + if count > 1: + axes = [i for i, n in enumerate(names) if n == name] + eye = _delta(operand.shape, axes) + operand = where(eye, operand, operand.zeros_like()) + if name not in keep_names: + operand = operand.sum(axes) + names = names.replace(name, "") + else: + operand = operand.sum(axes[:-1]) + names = names.replace(name, "", count - 1) + return operand, names + + def filter_singleton_dims(operand, names, other_operand, other_names): + op_bcast = operand.type.broadcastable + other_bcast = other_operand.type.broadcastable + keep = [ + (not op_bcast[i]) or (j == -1) or other_bcast[j] + for i, j in enumerate(map(other_names.find, names)) + ] + keep_axes = [i for i, keep_axis in enumerate(keep) if keep_axis] + squeeze_axes = [i for i, keep_axis in enumerate(keep) if not keep_axis] + if squeeze_axes: + # TODO: We could modify the subscripts to avoid the problem? + warnings.warn( + "The same einsum subscript is used for a broadcastable and non-broadcastable dimension. " + "This can result in a suboptimal contraction path." + ) + return operand.squeeze(squeeze_axes), "".join(names[i] for i in keep_axes) + + einsum_operands = list(tensor_operands) # So we can pop + for operand_indices, contracted_names, einstr, _, _ in contraction_list: + contracted_names = sorted(contracted_names) + assert len(contracted_names) == len( + set(contracted_names) + ), "The set was needed!" + + input_str, result_names = einstr.split("->") + input_names = input_str.split(",") + + # switch on the number of operands to be processed in this loop iteration. + # every case here sets 'operand' and 'names'. 
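+        # With a single operand, the step below only needs index reductions (summing indices that +        # appear once or repeatedly within that operand); with two operands it additionally +        # performs a (possibly batched) tensordot via `_general_dot`.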
+ if len(operand_indices) == 1: + operand = einsum_operands.pop(operand_indices[0]) + (names,) = input_names + counts = collections.Counter(names) + + # sum out unique contracted indices with a single reduce-sum + uniques = [name for name in contracted_names if counts[name] == 1] + operand, names = sum_uniques(operand, names, uniques) + + # for every repeated index, do a contraction against an identity matrix + operand, names = sum_repeats(operand, names, counts, result_names) + + elif len(operand_indices) == 2: + lhs, rhs = map(einsum_operands.pop, operand_indices) + lhs_names, rhs_names = input_names + + # handle cases where one side of a contracting or batch dimension is 1 + # but its counterpart is not. + lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, rhs, rhs_names) + rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, lhs, lhs_names) + + lhs_counts = collections.Counter(lhs_names) + rhs_counts = collections.Counter(rhs_names) + + # sum out unique contracted indices in lhs and rhs + lhs_uniques = [ + name + for name in contracted_names + if lhs_counts[name] == 1 and rhs_counts[name] == 0 + ] + lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques) + + rhs_uniques = [ + name + for name in contracted_names + if rhs_counts[name] == 1 and lhs_counts[name] == 0 + ] + rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques) + + # for every repeated index, contract against an identity matrix + lhs, lhs_names = sum_repeats( + lhs, lhs_names, lhs_counts, result_names + rhs_names + ) + rhs, rhs_names = sum_repeats( + rhs, rhs_names, rhs_counts, result_names + lhs_names + ) + + lhs_or_rhs_names = set(lhs_names) | set(rhs_names) + contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names] + lhs_and_rhs_names = set(lhs_names) & set(rhs_names) + batch_names = [x for x in result_names if x in lhs_and_rhs_names] + + if batch_names: + lhs_batch, rhs_batch = tuple( + zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names]) + ) + else: + lhs_batch = rhs_batch = () + + # contract using dot_general + batch_names_str = "".join(batch_names) + if contracted_names: + lhs_cont, rhs_cont = tuple( + zip( + *[ + (lhs_names.index(n), rhs_names.index(n)) + for n in contracted_names + ] + ) + ) + else: + lhs_cont = rhs_cont = () + deleted_names = batch_names_str + "".join(contracted_names) + remaining_lhs_names = removechars(lhs_names, deleted_names) + remaining_rhs_names = removechars(rhs_names, deleted_names) + # Try both orders of lhs and rhs, in the hope that one of them means we + # don't need an explicit transpose. opt_einsum likes to contract from + # right to left, so we expect (rhs,lhs) to have the best chance of not + # needing a transpose. 
+ names = batch_names_str + remaining_rhs_names + remaining_lhs_names + if names == result_names: + operand = _general_dot( + (rhs, lhs), (rhs_cont, lhs_cont), (rhs_batch, lhs_batch) + ) + else: + names = batch_names_str + remaining_lhs_names + remaining_rhs_names + operand = _general_dot( + (lhs, rhs), + axes=(lhs_cont, rhs_cont), + batch_axes=(lhs_batch, rhs_batch), + ) + else: + raise ValueError( + f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}" + ) + + # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result + assert len(names) == len(result_names) == len(set(names)) + assert set(names) == set(result_names) + if names != result_names: + perm = tuple(names.index(name) for name in result_names) + operand = transpose(operand, perm) + einsum_operands.append(operand) # used in next iteration + + [einsum_result] = einsum_operands + + out = Einsum( + subscripts=subscripts, + inputs=list(tensor_operands), + outputs=[einsum_result], + path=tuple(path), + optimized=optimized, + )(*tensor_operands) + return cast(TensorVariable, out) diff --git a/pytensor/tensor/functional.py b/pytensor/tensor/functional.py index e7a5371b02..05e11f2643 100644 --- a/pytensor/tensor/functional.py +++ b/pytensor/tensor/functional.py @@ -1,8 +1,8 @@ from collections.abc import Callable from pytensor.graph import vectorize_graph -from pytensor.tensor import TensorVariable from pytensor.tensor.utils import _parse_gufunc_signature +from pytensor.tensor.variable import TensorVariable def vectorize(func: Callable, signature: str | None = None) -> Callable: diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 168b636041..fc5c528f2d 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -3,10 +3,9 @@ import pytensor.tensor.rewriting.blas_c import pytensor.tensor.rewriting.blas_scipy import pytensor.tensor.rewriting.blockwise +import pytensor.tensor.rewriting.einsum import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.extra_ops - -# Register JAX specializations import pytensor.tensor.rewriting.jax import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.math diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 4a7570dad3..6a038cab15 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -52,6 +52,7 @@ TensorFromScalar, alloc, as_tensor_variable, + atleast_Nd, cast, extract_constant, fill, @@ -1219,3 +1220,123 @@ def local_merge_alloc(fgraph, node): register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy") + + +@register_specialize +@node_rewriter([DimShuffle]) +def local_dimshuffle_alloc(fgraph, node): + """ + Lift DimShuffle through Alloc + + dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2) + """ + alloc_out = node.inputs[0] + alloc_node = alloc_out.owner + if not (alloc_node and isinstance(alloc_node.op, Alloc)): + return + + ds_op = node.op + value, *alloc_shape = alloc_node.inputs + + # Add implicit dimensions of value + value = atleast_Nd(value, n=len(alloc_shape)) + + # Dimshuffle value and alloc_shape + ds_value = value.dimshuffle(ds_op.new_order) + ds_alloc_shape = [alloc_shape[i] for i in ds_op.shuffle] + for dim in ds_op.augment: + ds_alloc_shape.insert(dim, 1) + + return [alloc(ds_value, *ds_alloc_shape)] + + +@register_specialize("shape_unsafe") +@node_rewriter([Join]) +def local_join_of_alloc(fgraph, node): + 
"""Rewrite a Join of Alloc nodes to an Alloc of the Join nodes.""" + axis, *tensors = node.inputs + + if len(tensors) < 2: + # Let other rewrite handle the useless Join + return + + if not isinstance(axis, Constant): + return + + core_tensors = [] + alloc_shapes = [] + for tensor in tensors: + if tensor.owner is None: + return + + # tensor = expand_dims_to_alloc(tensor) + if not isinstance(tensor.owner.op, Alloc): + return + + value, *shape = tensor.owner.inputs + # Introduce explicit batch dims + value = atleast_Nd(value, n=len(shape)) + core_tensors.append(value) + alloc_shapes.append(shape) + + # Find which allocated dimensions can be lifted + # Axis can never be lifted + # Non-axis allocated dimensions can be lifted if they are all broadcastable + [out] = node.outputs + axis = axis.data + + broadcasted_dims = list( + zip( + *( + [ + bef and not aft + for bef, aft in zip( + core_tensor.type.broadcastable, + tensor.type.broadcastable, + strict=True, + ) + ] + for core_tensor, tensor in zip(core_tensors, tensors, strict=True) + ) + ) + ) + + lifteable_alloc_dims = { + dim + for dim in range(out.type.ndim) + if dim != axis and all(broadcasted_dims[dim]) + } + + if not lifteable_alloc_dims: + return + + # Lift the allocated dimensions + new_tensors = [] + for core_tensor, alloc_shape in zip(core_tensors, alloc_shapes): + pre_join_shape = [ + 1 if i in lifteable_alloc_dims else alloc_dim + for i, alloc_dim in enumerate(alloc_shape) + ] + new_tensor = alloc(core_tensor, *pre_join_shape) + copy_stack_trace(tensor, new_tensor) + new_tensors.append(new_tensor) + + new_join = node.op(axis, *new_tensors) + copy_stack_trace(node.outputs[0], new_join) + + # Reintroduce the lifted dims + post_join_shape = [] + for i, alloc_dims in enumerate(zip(*alloc_shapes)): + if i == axis: + # The alloc dim along the axis is the sum of all the pre-join alloc dims + post_join_shape.append(add(*alloc_dims)) + else: + # Otherwise the shapes should all match. We prioritize constants if any + for best_alloc_dim in alloc_dims: + if isinstance(best_alloc_dim, Constant): + break + post_join_shape.append(best_alloc_dim) + + new_out = alloc(new_join, *post_join_shape) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 0bed304c29..7220824c58 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -10,6 +10,7 @@ register_specialize, register_stabilize, ) +from pytensor.tensor.shape import Reshape from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor @@ -67,10 +68,16 @@ def local_useless_unbatched_blockwise(fgraph, node): def local_eager_useless_unbatched_blockwise(fgraph, node): if isinstance( node.op.core_op, - Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor, + Dot + | Alloc + | ARange + | Subtensor + | AdvancedSubtensor + | AdvancedIncSubtensor + | Reshape, ): # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize - # These other Ops can't always be trivially vectored at runtime, + # These other Ops can't always be trivially vectorized at runtime, # since their inputs may imply non-rectangular shapes. 
return local_useless_unbatched_blockwise.fn(fgraph, node) @@ -97,62 +104,67 @@ def local_blockwise_alloc(fgraph, node): BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector) """ - if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner): - return None - op: Blockwise = node.op # type: ignore batch_ndim = op.batch_ndim(node) if not batch_ndim: return None + if not any(var.owner and isinstance(var.owner.op, Alloc) for var in node.inputs): + return None + new_inputs = [] batch_shapes = [] can_push_any_alloc = False for inp, inp_sig in zip(node.inputs, op.inputs_sig): - if inp.owner and isinstance(inp.owner.op, Alloc): - # Push batch dims from Alloc - value, *shape = inp.owner.inputs - - # Check what to do with the value of the Alloc - squeezed_value = _squeeze_left(value, batch_ndim) - missing_ndim = len(shape) - value.type.ndim - if ( - (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]) - != inp.type.broadcastable[batch_ndim:] - ): - # We still need an Alloc for the core dims - core_shape = shape[batch_ndim:] - # And the batch dims of the squeezed value - squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape) - batch_shape = [ - 1 if broadcastable else dim - for broadcastable, dim in zip( - squeezed_value.type.broadcastable[:squeezed_value_batch_ndim], - tuple(squeezed_value.shape)[:squeezed_value_batch_ndim], + if not all(inp.type.broadcastable[:batch_ndim]): + if inp.owner and isinstance(inp.owner.op, Alloc): + # Push batch dims from Alloc + value, *shape = inp.owner.inputs + + # Check what to do with the value of the Alloc + squeezed_value = _squeeze_left(value, batch_ndim) + missing_ndim = len(shape) - value.type.ndim + if ( + (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]) + != inp.type.broadcastable[batch_ndim:] + ): + # We still need an Alloc for the core dims + core_shape = shape[batch_ndim:] + # And the batch dims of the squeezed value + squeezed_value_batch_ndim = squeezed_value.type.ndim - len( + core_shape ) - ] - squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape) - if squeezed_value.type.broadcastable == inp.type.broadcastable: - # We can't change anything about this Alloc input - new_inputs.append(inp) - continue - - # We can push batch dims of this Alloc input - batch_shapes.append( - tuple( - 1 if broadcastable else dim - for broadcastable, dim in zip( - inp.type.broadcastable, shape[:batch_ndim] + batch_shape = [ + 1 if broadcastable else dim + for broadcastable, dim in zip( + squeezed_value.type.broadcastable[ + :squeezed_value_batch_ndim + ], + tuple(squeezed_value.shape)[:squeezed_value_batch_ndim], + ) + ] + squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape) + if squeezed_value.type.broadcastable == inp.type.broadcastable: + # We can't change anything about this Alloc input + new_inputs.append(inp) + continue + + # We can push batch dims of this Alloc input + batch_shapes.append( + tuple( + 1 if broadcastable else dim + for broadcastable, dim in zip( + inp.type.broadcastable, shape[:batch_ndim] + ) ) ) - ) - new_inputs.append(squeezed_value) - can_push_any_alloc = True + new_inputs.append(squeezed_value) + can_push_any_alloc = True + continue - else: - # Nothing to do with this input other than removing dummy batch dims - new_inputs.append(_squeeze_left(inp, batch_ndim)) + # Nothing to do with this input other than removing dummy batch dims + new_inputs.append(_squeeze_left(inp, batch_ndim)) if not can_push_any_alloc: return None @@ -167,17 +179,15 @@ def 
local_blockwise_alloc(fgraph, node): missing_ndim = old_out_type.ndim - new_out_type.ndim batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim] for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples + if old_out_type.broadcastable[i]: + continue for batch_dim in batch_dims: if batch_dim == 1: continue + batch_shape[i] = batch_dim if isinstance(batch_dim, Constant): # Give preference to Constants - batch_shape[i] = batch_dim break - elif old_out_type.broadcastable[i]: - # Only use non Constant shapes if absolutely necessary - # Otherwise, we use the shape of the non-alloc output - batch_shape[i] = batch_dim copy_stack_trace(node.outputs, new_outs) new_outs = [ @@ -190,3 +200,28 @@ def local_blockwise_alloc(fgraph, node): ] copy_stack_trace(node.outputs, new_outs) return new_outs + + +@register_specialize +@node_rewriter([Blockwise]) +def local_blockwise_reshape(fgraph, node): + """Rewrite away square Blockwise reshapes. + + Reshape is tricky to vectorize eagerly, because a graph like + `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations + that must be vectorized before we arrive at the reshape operation. + + For the square Reshape case, we must wait for all the intermediate + operations to be lifted as Allocs. + """ + if not isinstance(node.op.core_op, Reshape): + return None + + x, output_shape = node.inputs + batch_ndim = node.op.batch_ndim(node) + if all(output_shape.type.broadcastable[:batch_ndim]): + batched_shape = x.shape[:batch_ndim] + core_reshape = _squeeze_left(output_shape, batch_ndim) + new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)]) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] diff --git a/pytensor/tensor/rewriting/einsum.py b/pytensor/tensor/rewriting/einsum.py new file mode 100644 index 0000000000..5e9fe2d026 --- /dev/null +++ b/pytensor/tensor/rewriting/einsum.py @@ -0,0 +1,53 @@ +from typing import cast + +from pytensor.graph import Apply, FunctionGraph, node_rewriter +from pytensor.graph.rewriting.basic import copy_stack_trace +from pytensor.tensor.einsum import Einsum, einsum +from pytensor.tensor.rewriting.basic import register_specialize +from pytensor.tensor.rewriting.ofg import inline_ofg_node +from pytensor.tensor.variable import TensorVariable + + +@register_specialize +@node_rewriter([Einsum]) +def optimize_einsum_inner_graph( + fgraph: FunctionGraph, node: Apply +) -> list[TensorVariable] | None: + """Try to optimize an einsum that was not optimizable at definition time. + + This can happen when users replace variables in a graph (e.g., via `clone_replace`) without rebuilding the Op. + + It can also happen when more specialized static shapes become known during the course of rewrites. + """ + op: Einsum = node.op + + if op.optimized: + # Already optimized + return None + + operands = node.inputs + if any(None in operand.type.shape for operand in operands): + return None + + new_out = einsum(op.subscripts, *operands) + assert new_out.owner.op.optimized + + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + +@register_specialize +@node_rewriter([Einsum]) +def inline_optimized_einsum( + fgraph: FunctionGraph, node: Apply +) -> list[TensorVariable] | None: + """Inline einsums that are already optimized. + + This allows the inner graph to be optimized with the rest of the graph, now that we got the ordering right.
+ """ + op: Einsum = node.op + + if not op.optimized: + return None + + return cast(list[TensorVariable], inline_ofg_node(node)) diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py index 265f3ff2e8..2c4dfc4f70 100644 --- a/pytensor/tensor/rewriting/ofg.py +++ b/pytensor/tensor/rewriting/ofg.py @@ -1,12 +1,24 @@ -from pytensor import clone_replace +from typing import cast + +from pytensor import Variable, clone_replace from pytensor.compile import optdb from pytensor.compile.builders import OpFromGraph -from pytensor.graph import node_rewriter +from pytensor.graph import Apply, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, in2out from pytensor.tensor.basic import AllocDiag from pytensor.tensor.rewriting.basic import register_specialize +def inline_ofg_node(node: Apply) -> list[Variable]: + op = node.op + assert isinstance(op, OpFromGraph) + inlined_outs = clone_replace( + op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)) + ) + copy_stack_trace(op.inner_outputs, inlined_outs) + return cast(list[Variable], inlined_outs) + + @node_rewriter([OpFromGraph]) def inline_ofg_expansion(fgraph, node): """ @@ -18,10 +30,7 @@ def inline_ofg_expansion(fgraph, node): if not op.is_inline: return False - new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) - copy_stack_trace(op.inner_outputs, new_out) - - return new_out + return inline_ofg_node(node) # We want to run this before the first merge optimizer @@ -61,8 +70,4 @@ def late_inline_OpFromGraph(fgraph, node): ------- """ - op = node.op - new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) - copy_stack_trace(op.inner_outputs, new_out) - - return new_out + return inline_ofg_node(node) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 1426a7d993..afa94d4e1f 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -749,51 +749,43 @@ def apply(self, fgraph): pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) -def local_reshape_chain(op): - @node_rewriter([op]) - def f(fgraph, node): - """ - Reshape(Reshape(shape1),shape2) -> Reshape(shape2) - - """ - if not check_chain(node, op, op): - return False - - # TODO: this can permit a failing program to run by eliminating - # the lower reshape - rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) - - # Copy over stacktrace from previous output node, as any error - # in new computational graph would have been caused by last op - # in the old computational graph. - copy_stack_trace(node.outputs, rval) - - # It might happen that the desired output of this node has a - # broadcastable pattern that does not match that of 'rval'. This is - # when originally, we were able to figure out that one of the - # dimensions of the reshape is one, but some other transformation - # replaced the shape by one for which this cannot be guessed. - # We should try to figure out why we lost the information about this - # constant value... but in the meantime, better not apply this - # rewrite. 
- if rval.type.ndim == node.outputs[0].type.ndim and all( - s1 == s2 - for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape) - if s1 == 1 or s2 == 1 - ): - return [rval] - else: - return False - - return f +@register_canonicalize("shape_unsafe") +@register_specialize("shape_unsafe") +@node_rewriter([Reshape]) +def local_reshape_chain(fgraph, node): + """ + Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2) + """ + if not check_chain(node, Reshape, Reshape): + return False -register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain") + rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) + + # Copy over stacktrace from previous output node, as any error + # in new computational graph would have been caused by last op + # in the old computational graph. + copy_stack_trace(node.outputs, rval) + + # It might happen that the desired output of this node has a + # broadcastable pattern that does not match that of 'rval'. This is + # when originally, we were able to figure out that one of the + # dimensions of the reshape is one, but some other transformation + # replaced the shape by one for which this cannot be guessed. + # We should try to figure out why we lost the information about this + # constant value... but in the meantime, better not apply this + # rewrite. + if rval.type.ndim == node.outputs[0].type.ndim and all( + s1 == s2 + for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape) + if s1 == 1 or s2 == 1 + ): + return [rval] -@register_useless -@register_canonicalize -@register_stabilize +@register_useless("shape_unsafe") +@register_canonicalize("shape_unsafe") +@register_specialize("shape_unsafe") @node_rewriter([Reshape]) def local_useless_reshape(fgraph, node): """Remove two kinds of useless `Reshape`. @@ -802,24 +794,17 @@ def local_useless_reshape(fgraph, node): - Remove `Reshape` when reshaping to the shape of the input. """ - inp = node.inputs[0] - output = node.outputs[0] - output_shape = node.inputs[1] + inp, output_shape = node.inputs + [output] = node.outputs if inp.type.ndim != output.type.ndim: return False # Simple case: both input and output have a single dimension. - # TODO FIXME XXX: This could hide errors if the user provides inconsistent - # shapes. 
if ( inp.type.ndim == 1 and output.type.ndim == 1 - and all( - s1 == s2 - for s1, s2 in zip(inp.type.shape, output.type.shape) - if s1 == 1 or s2 == 1 - ) + and inp.type.broadcastable == output.type.broadcastable ): return [inp] @@ -832,8 +817,15 @@ def local_useless_reshape(fgraph, node): # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for # broadcastable and constant dimensions - if output_shape.owner and isinstance(output_shape.owner.op, MakeVector): - output_shape_is = output_shape.owner.inputs + if isinstance(output_shape, Constant) or ( + output_shape.owner and isinstance(output_shape.owner.op, MakeVector) + ): + if isinstance(output_shape, Constant): + output_shape_is = [ + as_tensor_variable(dim, ndim=0) for dim in output_shape.data + ] + else: + output_shape_is = output_shape.owner.inputs shape_feature = getattr(fgraph, "shape_feature", None) @@ -865,9 +857,9 @@ def local_useless_reshape(fgraph, node): shape_match[dim] = True continue - # Match 1 if input.type.shape[dim] == 1 + # Match constant if input.type.shape[dim] == constant cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) - if inp.type.shape[dim] == 1 and cst_outshp_i == 1: + if inp.type.shape[dim] == cst_outshp_i: shape_match[dim] = True continue @@ -881,17 +873,18 @@ def local_useless_reshape(fgraph, node): if shape_feature: inpshp_i = shape_feature.get_shape(inp, dim) if inpshp_i == outshp_i or ( - extract_constant(inpshp_i, only_process_constants=1) - == extract_constant(outshp_i, only_process_constants=1) + extract_constant(inpshp_i, only_process_constants=True) + == extract_constant(outshp_i, only_process_constants=True) ): shape_match[dim] = True continue - if all(shape_match) and nb_m1 <= 1: + if nb_m1 <= 1 and all(shape_match): + return [inp] + + if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1): return [inp] - # TODO later: if all the shapes except one match, we may want to - # consider it useless as well, like we do in the 1-dim case. 
return False @@ -910,9 +903,8 @@ def local_reshape_to_dimshuffle(fgraph, node): -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n))) """ op = node.op - inp = node.inputs[0] - output = node.outputs[0] - output_shape = node.inputs[1] + inp, output_shape = node.inputs + [output] = node.outputs dimshuffle_new_order = [] new_output_shape = [] @@ -944,7 +936,7 @@ def local_reshape_to_dimshuffle(fgraph, node): @register_canonicalize -@register_stabilize +@register_specialize @node_rewriter([Reshape]) def local_reshape_lift(fgraph, node): """ diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 236c34b442..614258dcae 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -842,13 +842,13 @@ def c_code(self, node, name, inputs, outputs, sub): @_vectorize_node.register(Reshape) def _vectorize_reshape(op, node, x, shape): + from pytensor.tensor.blockwise import vectorize_node_fallback + old_x, old_shape = node.inputs batched_ndims = x.type.ndim - old_x.type.ndim if as_tensor_variable(shape).type.ndim != 1: - raise NotImplementedError( - "It is not possible to vectorize the shape argument of Reshape" - ) + return vectorize_node_fallback(op, node, x, shape) if len(tuple(old_shape)) == len(tuple(shape)): new_shape = [*x.shape[:batched_ndims], *shape] diff --git a/tests/link/jax/test_einsum.py b/tests/link/jax/test_einsum.py new file mode 100644 index 0000000000..9a55670c64 --- /dev/null +++ b/tests/link/jax/test_einsum.py @@ -0,0 +1,38 @@ +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt + + +jax = pytest.importorskip("jax") + + +def test_jax_einsum(): + subscripts = "ij, jk, kl -> il" + x = np.random.rand(3, 5) + y = np.random.rand(5, 2) + z = np.random.rand(2, 4) + + shapes = ((3, 5), (5, 2), (2, 4)) + x_pt, y_pt, z_pt = ( + pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes) + ) + out = pt.einsum(subscripts, x_pt, y_pt, z_pt) + f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX") + + np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z)) + + +@pytest.mark.xfail(raises=NotImplementedError) +def test_ellipsis_einsum(): + subscripts = "...i,...i->..." 
+ x = np.random.rand(2, 5) + y = np.random.rand(2, 5) + + x_pt = pt.tensor("x", shape=x.shape) + y_pt = pt.tensor("y", shape=y.shape) + out = pt.einsum(subscripts, x_pt, y_pt) + f = pytensor.function([x_pt, y_pt], out, mode="JAX") + + np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y)) diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py index d5ea6e2b4e..a17ad18a1f 100644 --- a/tests/tensor/rewriting/test_blockwise.py +++ b/tests/tensor/rewriting/test_blockwise.py @@ -1,7 +1,9 @@ from functools import partial -from pytensor import function -from pytensor.graph import FunctionGraph, rewrite_graph +import numpy as np + +from pytensor import Mode, config, function +from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph from pytensor.graph.basic import equal_computations from pytensor.scalar import log as scalar_log from pytensor.tensor import add, alloc, matrix, tensor, tensor3 @@ -9,6 +11,7 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.nlinalg import MatrixPinv from pytensor.tensor.rewriting.blockwise import local_useless_blockwise +from pytensor.tensor.shape import Reshape def test_useless_blockwise_of_elemwise(): @@ -45,7 +48,7 @@ def test_blockwise_alloc(): rewrite = partial( rewrite_graph, include=("ShapeOpt", "specialize"), - exclude=("local_useless_unbatched_blockwise",), + exclude=("local_useless_unbatched_blockwise", "local_dimshuffle_alloc"), ) vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)") @@ -104,7 +107,9 @@ def test_blockwise_alloc(): y = tensor("y", shape=()) out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5)) expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5) - assert equal([rewrite(out)], [expected_out]) + assert equal( + [rewrite(out)], [expected_out] + ), None # pytensor.dprint([expected_out, rewrite(out)], print_type=True) x = tensor("x", shape=(5,)) y = tensor("y", shape=()) @@ -118,3 +123,27 @@ def test_blockwise_alloc(): out = vector_add(x, alloc(y, 5)) expected_out = out assert equal([rewrite(out)], [expected_out]) + + +def test_blockwise_reshape(): + x = tensor("x", shape=(None, None, None)) + y = x.reshape([x.shape[0] * x.shape[1], -1]) + + new_x = tensor("x", shape=(None, None, None, None)) + new_y = vectorize_graph(y, {x: new_x}) + assert not isinstance(new_y.owner.op, Reshape) + assert isinstance(new_y.owner.op, Blockwise) and isinstance( + new_y.owner.op.core_op, Reshape + ) + + rewritten_y = rewrite_graph( + new_y, include=("canonicalize", "specialize"), clone=True + ) + assert isinstance(rewritten_y.owner.op, Reshape) + + no_rewrites = Mode(linker="py", optimizer=None) + test_x = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2).astype(config.floatX) + np.testing.assert_allclose( + new_y.eval({"x": test_x}, mode=no_rewrites), + rewritten_y.eval({"x": test_x}, mode=no_rewrites), + ) diff --git a/tests/tensor/rewriting/test_einsum.py b/tests/tensor/rewriting/test_einsum.py new file mode 100644 index 0000000000..73e4372aaa --- /dev/null +++ b/tests/tensor/rewriting/test_einsum.py @@ -0,0 +1,39 @@ +from functools import partial + +from pytensor.graph import ancestors, rewrite_graph +from pytensor.tensor import einsum, specify_shape, tensor +from pytensor.tensor.einsum import Einsum + + +specialize_rewrite = partial(rewrite_graph, include=("specialize",), clone=True) + + +def test_einsum_optimization(): + a = tensor("a", shape=(None, None)) + b = tensor("b", shape=(None, None)) + c = tensor("c", shape=(None, None)) + + dynamic_shape_einsum = 
einsum("ij,ij,jk->ik", a, b, c) + assert not dynamic_shape_einsum.owner.op.optimized + + rewritten_out = specialize_rewrite(dynamic_shape_einsum) + assert isinstance(rewritten_out.owner.op, Einsum) + + a = specify_shape(a, (2, 3)) + b = specify_shape(b, (2, 3)) + c = specify_shape(c, (3, 5)) + + static_shape_einsum = dynamic_shape_einsum.owner.clone_with_new_inputs( + [a, b, c] + ).default_output() + assert not static_shape_einsum.owner.op.optimized + + rewritten_out = specialize_rewrite(static_shape_einsum) + # Einsum was inlined because it was optimized + assert not isinstance(rewritten_out.owner.op, Einsum) + # Sanity check that it's not buried in the graph + assert not any( + isinstance(var.owner.op, Einsum) + for var in ancestors([rewritten_out]) + if var.owner + ) diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index f4c529a0d2..bbfd829070 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -337,6 +337,52 @@ def test_m1(self): topo = f2.maker.fgraph.toposort() assert not any(isinstance(n.op, Reshape) for n in topo) + def test_constant_shape(self): + # Where reshape is a constant that matches the shape + x = matrix(shape=(2, 3)) + shape = pt.as_tensor(np.array([2, 3])) + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is x + + x = matrix(shape=(2, 3)) + shape = pt.as_tensor(np.array([-1, 3])) + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is x + + x = matrix(shape=(None, 3)) + shape = pt.as_tensor(np.array([-1, 3])) + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is x + + x = matrix(shape=(None, 3)) + shape = pt.as_tensor(np.array([2, 3])) + out = reshape(x, shape) + new_out = rewrite_graph(out) + # This could be rewritten as a specify_shape(x, (2, 3)) + assert new_out is not x + + x = matrix(shape=(2, 3)) + shape = pt.as_tensor(np.array([3, 2])) + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is not x + + def test_all_but_one_match(self): + x = matrix(shape=(None, None)) + shape = [x.shape[0], 3] + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert equal_computations([new_out], [specify_shape(x, (None, 3))]) + + # Rewrite does not apply if there's also a -1 + shape = [-1, 3] + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is out + class TestLocalReshapeToDimshuffle: def setup_method(self): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 49c8e9c38c..58d4de2481 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3847,8 +3847,10 @@ def test_transpose(): assert np.all(t2d == np.transpose(x2v, [0, 1])) assert np.all(t3d == np.transpose(x3v, [0, 2, 1])) + # Check we don't introduce useless transpose + assert ptb.transpose(x1) is x1 + # Check that we create a name. 
- assert ptb.transpose(x1).name == "x1.T" assert ptb.transpose(x2).name == "x2.T" assert ptb.transpose(x3).name == "x3.T" assert ptb.transpose(dmatrix()).name is None diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py new file mode 100644 index 0000000000..749496e1ef --- /dev/null +++ b/tests/tensor/test_einsum.py @@ -0,0 +1,257 @@ +from functools import partial +from string import ascii_lowercase + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor import Mode, config, function +from pytensor.graph import FunctionGraph +from pytensor.graph.op import HasInnerGraph +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum +from pytensor.tensor.shape import Reshape + + +# Fail for unexpected warnings in this file +pytestmark = pytest.mark.filterwarnings("error") + +floatX = pytensor.config.floatX +ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4 + + +def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None: + for node in fgraph.apply_nodes: + if isinstance(node.op, Blockwise): + if core_op is None: + raise AssertionError + assert not isinstance(node.op.core_op, core_op) + + if isinstance(node.op, HasInnerGraph): + # InnerGraph Ops can be rewritten without modifying the original fgraph + if hasattr(node.op, "_fn"): + inner_fgraph = node.op._fn.maker.fgraph + else: + inner_fgraph = node.op.fgraph + assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op) + + +def test_iota(): + mode = Mode(linker="py", optimizer=None) + np.testing.assert_allclose( + _iota((4, 8), 0).eval(mode=mode), + [ + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2, 2, 2], + [3, 3, 3, 3, 3, 3, 3, 3], + ], + ) + + np.testing.assert_allclose( + _iota((4, 8), 1).eval(mode=mode), + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + ], + ) + + +def test_delta(): + mode = Mode(linker="py", optimizer=None) + np.testing.assert_allclose( + _delta((2, 2), (0, 1)).eval(mode=mode), + [[1.0, 0.0], [0.0, 1.0]], + ) + + +def test_general_dot(): + rng = np.random.default_rng(45) + signature = "(l0,a0,a1,l1),(a1,r0,r1,a0)->(l0,l1,r0,r1)" + tensordot_axes = [(-3, -2), (-1, -4)] + + # X has two batch dims + # Y has one batch dim + x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3)) + y = pt.tensor("y", shape=(4, 13, 5, 7, 11)) + out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)]) + + fn = pytensor.function([x, y], out) + # fn.dprint(print_type=True) + if config.mode != "FAST_COMPILE": + assert_no_blockwise_in_graph(fn.maker.fgraph, Reshape) + + np_batched_tensordot = np.vectorize( + partial(np.tensordot, axes=tensordot_axes), signature=signature + ) + x_test = rng.normal(size=x.type.shape).astype(floatX) + y_test = rng.normal(size=y.type.shape).astype(floatX) + np.testing.assert_allclose( + fn(x_test, y_test), np_batched_tensordot(x_test, y_test), atol=ATOL, rtol=RTOL + ) + + +@pytest.mark.parametrize("static_shape_known", [True, False]) +@pytest.mark.parametrize( + "signature", + [ + "ij", + "ji", + "ii->i", + "ii", + "ij->", + "ij->j", + "ij->i", + "ij,ij->ij", + "ij,ji->ij", + "ij,ji->ji", + "ij,jk", + "kj,ji", + "ij,kj->ik", + "ik,kj->ikj", + "ij,kl->ijkl", + "ij,jk,kl->il", + "kl,ij,jk->il", + "oij,imj,mjkn,lnk,plk->op", + ], +) +def test_einsum_signatures(static_shape_known, signature): + letters_to_dims = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True)) + + inputs = 
signature.split("->")[0].split(",") + + shapes = [tuple(letters_to_dims[letter] for letter in inp) for inp in inputs] + if static_shape_known: + static_shapes = shapes + else: + static_shapes = [[None] * len(shape) for shape in shapes] + + operands = [ + pt.tensor(name, shape=static_shape) + for name, static_shape in zip(ascii_lowercase, static_shapes) + ] + out = pt.einsum(signature, *operands) + assert out.owner.op.optimized == static_shape_known or len(operands) <= 2 + + rng = np.random.default_rng(37) + test_values = [rng.normal(size=shape).astype(floatX) for shape in shapes] + np_out = np.einsum(signature, *test_values) + + fn = function(operands, out) + pt_out = fn(*test_values) + + # print(); fn.dprint(print_type=True) + + assert_no_blockwise_in_graph(fn.maker.fgraph) + np.testing.assert_allclose(pt_out, np_out, atol=ATOL, rtol=RTOL) + + +def test_batch_dim(): + shapes = ( + (7, 3, 5), + (5, 2), + ) + x, y = (pt.tensor(name, shape=shape) for name, shape in zip("xy", shapes)) + out = pt.einsum("mij,jk->mik", x, y) + + assert out.type.shape == (7, 3, 2) + + +def test_einsum_conv(): + # Adapted example from https://medium.com/latinxinai/vectorized-convolution-operation-using-numpy-b122fd52fba3 + rng = np.random.default_rng(125) + batch_size = 32 + channels = 3 + height = 8 + width = 8 + kernel_size = 2 + num_filters = 15 + conv_signature = "bchwkt,fckt->bfhw" + windowed_input = rng.random( + size=(batch_size, channels, height, width, kernel_size, kernel_size) + ).astype(floatX) + weights = rng.random(size=(num_filters, channels, kernel_size, kernel_size)).astype( + floatX + ) + result = einsum(conv_signature, windowed_input, weights).eval() + + assert result.shape == (32, 15, 8, 8) + np.testing.assert_allclose( + result, + np.einsum("bchwkt,fckt->bfhw", windowed_input, weights), + atol=ATOL, + rtol=RTOL, + ) + + +def test_ellipsis(): + rng = np.random.default_rng(159) + x = pt.tensor("x", shape=(3, 5, 7, 11)) + y = pt.tensor("y", shape=(3, 5, 11, 13)) + x_test = rng.normal(size=x.type.shape).astype(floatX) + y_test = rng.normal(size=y.type.shape).astype(floatX) + expected_out = np.matmul(x_test, y_test) + + with pytest.raises(ValueError): + pt.einsum("mp,pn->mn", x, y) + + out = pt.einsum("...mp,...pn->...mn", x, y) + np.testing.assert_allclose( + out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL + ) + + # Put batch axes in the middle + new_x = pt.moveaxis(x, -2, 0) + new_y = pt.moveaxis(y, -2, 0) + out = pt.einsum("m...p,p...n->m...n", new_x, new_y) + np.testing.assert_allclose( + out.eval({x: x_test, y: y_test}), + expected_out.transpose(-2, 0, 1, -1), + atol=ATOL, + rtol=RTOL, + ) + + out = pt.einsum("m...p,p...n->mn", new_x, new_y) + np.testing.assert_allclose( + out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL + ) + + +def test_broadcastable_dims(): + # Test that einsum handles broadcasting dims correctly. There are two points: + # 1. Numpy einsum allows the same subscript for degenerate and full dimensions + # There is some stale discussion on whether this should be a bug or not, but for now it is not: + # https://github.com/numpy/numpy/issues/11548 + + # 2. Using the same letter for dimensions that are and aren't broadcastable + # can lead to suboptimal paths. 
+    # https://github.com/dgasmith/opt_einsum/issues/220
+    rng = np.random.default_rng(222)
+    a = pt.tensor("a", shape=(32, 32, 32))
+    b = pt.tensor("b", shape=(1000, 32))
+    c = pt.tensor("c", shape=(1, 32))
+
+    a_test = rng.normal(size=a.type.shape).astype(floatX)
+    b_test = rng.normal(size=b.type.shape).astype(floatX)
+    c_test = rng.normal(size=c.type.shape).astype(floatX)
+
+    # Note the subscript "b" is used both for a broadcastable (length 1) and a non-broadcastable (length 1000) dimension
+    with pytest.warns(
+        UserWarning, match="This can result in a suboptimal contraction path"
+    ):
+        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
+    assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]
+
+    # If we use a distinct letter we get the optimal path
+    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
+    assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]
+
+    suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
+    optimal_eval = optimal_out.eval({a: a_test, b: b_test, c: c_test})
+    np_eval = np.einsum("ijk,bj,bk->i", a_test, b_test, c_test)
+    atol = 1e-12 if config.floatX == "float64" else 1e-2
+    np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
+    np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py
index 7fa8133c4e..f9434c9f60 100644
--- a/tests/tensor/test_shape.py
+++ b/tests/tensor/test_shape.py
@@ -14,7 +14,7 @@
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar.basic import ScalarConstant
 from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
-from pytensor.tensor.basic import MakeVector, as_tensor, constant
+from pytensor.tensor.basic import MakeVector, constant, stack
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.shape import (
@@ -801,8 +801,13 @@ def test_reshape(self):
         [vect_out] = vectorize_node(node, mat, new_shape).outputs
         assert equal_computations([vect_out], [reshape(mat, new_shape)])
 
-        with pytest.raises(NotImplementedError):
-            vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))
+        new_shape = stack([[-1, x], [x - 1, -1]], axis=0)
+        [vect_out] = vectorize_node(node, vec, new_shape).outputs
+        vec_test_value = np.arange(6)
+        np.testing.assert_allclose(
+            vect_out.eval({x: 3, vec: vec_test_value}),
+            np.broadcast_to(vec_test_value.reshape(2, 3), (2, 2, 3)),
+        )
 
         with pytest.raises(
             ValueError,