Add docstrings
jessegrabowski committed Jul 7, 2024
1 parent 5e83d26 commit 23a5e00
Showing 1 changed file with 157 additions and 11 deletions.
168 changes: 157 additions & 11 deletions pytensor/tensor/einsum.py
@@ -148,9 +148,42 @@ def _general_dot(
def contraction_list_from_path(
subscripts: str, operands: Sequence[TensorLike], path: PATH
):
"""TODO Docstrings
"""
Generate a list of contraction steps based on the provided einsum path.

Code adapted from opt_einsum: https://github.com/dgasmith/opt_einsum/blob/94c62a05d5ebcedd30f59c90b9926de967ed10b5/opt_einsum/contract.py#L369

When all shapes are known, the linked opt_einsum implementation is preferred. This implementation is used when
some or all shapes are not known. As a result, contraction will (always?) be done left-to-right, pushing intermediate
results to the end of the stack.
Parameters
----------
subscripts: str
Einsum signature string describing the computation to be performed.
operands: Sequence[TensorLike]
Tensors described by the subscripts.
path: tuple[tuple[int] | tuple[int, int]]
A list of tuples, where each tuple describes the indices of the operands to be contracted, sorted in the order
they should be contracted.
Returns
-------
contraction_list: list
A list of tuples, where each tuple describes a contraction step. Each tuple contains the following elements:
- contraction_inds: tuple[int]
The indices of the operands to be contracted
- idx_removed: str
The indices that are contracted away (i.e. removed from the einsum string) at this step
- einsum_str: str
The einsum string for the contraction step
- remaining: None
The remaining indices. Included to match the output of opt_einsum.contract_path, but not used.
- do_blas: None
Whether to use blas to perform this step. Included to match the output of opt_einsum.contract_path,
but not used.
"""
fake_operands = [
np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
@@ -199,9 +232,13 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
Code adapted from JAX: https://github.com/google/jax/blob/534d32a24d7e1efdef206188bb11ae48e9097092/jax/_src/numpy/lax_numpy.py#L5283
Einsum allows the user to specify a wide range of operations on tensors using the Einstein summation convention. Using
this notation, many common linear algebraic operations can be succinctly described on higher order tensors.
Parameters
----------
subscripts: str
Einsum signature string describing the computation to be performed.
operands: sequence of TensorVariable
Tensors to be multiplied and summed.
@@ -210,7 +247,110 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
-------
TensorVariable
The result of the einsum operation.
See Also
--------
pytensor.tensor.tensordot: Generalized dot product between two tensors
pytensor.tensor.dot: Matrix multiplication between two tensors
numpy.einsum: The numpy implementation of einsum
Examples
--------
Inputs to `pt.einsum` are a string describing the operation to be performed (the "subscripts") and a sequence of
tensors to be operated on. The string must conform to the following rules:
1. The string gives inputs and (optionally) outputs. Inputs and outputs are separated by "->".
2. The input side of the string is a comma-separated list of indices. For each comma-separated index string, there
must be a corresponding tensor in the input sequence.
3. For each index string, the number of dimensions in the corresponding tensor must match the number of characters
in the index string.
4. Each index is a single character. If an index appears multiple times on the input side, the corresponding
dimension must have the same length in every input in which it appears.
5. The indices on the output side must be a subset of the indices on the input side -- you cannot introduce new
indices in the output.
6. Ellipses ("...") can be used to elide multiple indices. This is useful when you have a large number of "batch"
dimensions that are not involved in the operation.
Finally, two rules about these indices govern how the computation is carried out:
1. Repeated indices on the input side indicate how the tensors should be "aligned" for multiplication.
2. Indices that appear on the input side but not the output side are summed over.
The operation of these rules is best understood via examples:
Example 1: Matrix multiplication
.. code-block:: python
    import pytensor.tensor as pt

    A = pt.matrix("A")
    B = pt.matrix("B")

    C = pt.einsum("ij, jk -> ik", A, B)
This computation is equivalent to :code:`C = A @ B`. Notice that the ``j`` index is repeated on the input side of the
signature, and does not appear on the output side. This indicates that the ``j`` dimension of the first tensor should be
multiplied with the ``j`` dimension of the second tensor, and the resulting tensor's ``j`` dimension should be summed
away.
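As a quick sanity check, the graph above can be compiled with ``pytensor.function`` and compared against NumPy's matrix product (a minimal sketch; the random test values are assumptions):

.. code-block:: python

    import numpy as np
    import pytensor

    f = pytensor.function([A, B], C)

    A_val = np.random.normal(size=(3, 4)).astype(A.dtype)
    B_val = np.random.normal(size=(4, 5)).astype(B.dtype)
    np.testing.assert_allclose(f(A_val, B_val), A_val @ B_val, rtol=1e-6)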
Example 2: Batched matrix multiplication
.. code-block:: python
    import pytensor.tensor as pt

    A = pt.tensor("A", shape=(None, 4, 5))
    B = pt.tensor("B", shape=(None, 5, 6))

    C = pt.einsum("bij, bjk -> bik", A, B)
This computation is also equivalent to :code:`C = A @ B` because of PyTensor's built-in broadcasting rules, but
the einsum signature is more explicit about the batch dimensions. The ``b`` and ``j`` indices are repeated on the
input side. Unlike ``j``, the ``b`` index is also present on the output side, indicating that the batch dimension
should **not** be summed away. As a result, multiplication will be performed over the ``b, j`` dimensions, and then
the ``j`` dimension will be summed over. The resulting tensor will have shape ``(None, 4, 6)``.
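To see the shape claim above, one can inspect the symbolic type of ``C``; as an additional hedged check, a compiled evaluation can be compared against NumPy's batched matmul (the batch size of 7 is an arbitrary assumption):

.. code-block:: python

    import numpy as np
    import pytensor

    print(C.type.shape)  # expected: (None, 4, 6), as described above

    f = pytensor.function([A, B], C)
    A_val = np.random.normal(size=(7, 4, 5)).astype(A.dtype)
    B_val = np.random.normal(size=(7, 5, 6)).astype(B.dtype)
    np.testing.assert_allclose(f(A_val, B_val), A_val @ B_val, rtol=1e-6)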
Example 3: Batched matrix multiplication with ellipses
.. code-block:: python
    import pytensor.tensor as pt

    A = pt.tensor("A", shape=(4, None, None, None, 5))
    B = pt.tensor("B", shape=(5, None, None, None, 6))

    C = pt.einsum("i...j, j...k -> ...ik", A, B)
This case is the same as above, but inputs ``A`` and ``B`` have multiple batch dimensions. To avoid writing out all
of the batch dimensions (which we do not care about), we can use ellipses to elide over these dimensions. Notice
also that we are not required to "sort" the input dimensions in any way. In this example, we are doing a dot
between the last dimension of ``A`` and the first dimension of ``B``, which is perfectly valid.
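A hedged way to confirm the ellipsis behaviour is to compare a compiled evaluation against ``numpy.einsum`` with the same signature (the concrete batch sizes below are arbitrary assumptions):

.. code-block:: python

    import numpy as np
    import pytensor

    f = pytensor.function([A, B], C)

    A_val = np.random.normal(size=(4, 2, 3, 2, 5)).astype(A.dtype)
    B_val = np.random.normal(size=(5, 2, 3, 2, 6)).astype(B.dtype)
    np.testing.assert_allclose(
        f(A_val, B_val),
        np.einsum("i...j, j...k -> ...ik", A_val, B_val),
        rtol=1e-6,
    )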
Example 4: Outer product
.. code-block:: python
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(3,))
    y = pt.tensor("y", shape=(4,))

    z = pt.einsum("i, j -> ij", x, y)
This computation is equivalent to :code:`pt.outer(x, y)`. Notice that no indices are repeated on the input side,
and the output side has two indices. Since there are no indices to align on, the einsum operation will simply
multiply the two tensors elementwise, broadcasting dimensions ``i`` and ``j``.
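As an optional check (the test values are assumptions), the result can be evaluated with ``.eval`` and compared against ``numpy.outer``:

.. code-block:: python

    import numpy as np

    x_val = np.random.normal(size=(3,)).astype(x.dtype)
    y_val = np.random.normal(size=(4,)).astype(y.dtype)
    np.testing.assert_allclose(z.eval({x: x_val, y: y_val}), np.outer(x_val, y_val), rtol=1e-6)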
Example 5: Convolution
.. code-block:: python
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(None, None, None, None, None, None))
    w = pt.tensor("w", shape=(None, None, None, None))

    y = pt.einsum("bchwkt,fckt->bfhw", x, w)
Given a batch of images ``x``, expanded into overlapping patches with dimensions
``(batch, channel, height, width, kernel_height, kernel_width)``, and a filter ``w`` with dimensions
``(num_filters, channel, kernel_height, kernel_width)``, this einsum operation computes the convolution of ``x``
with ``w``. Multiplication is aligned on the channel and the two kernel dimensions, which appear in both inputs
and are summed away; the batch, height, and width dimensions of ``x`` and the ``num_filters`` dimension of ``w``
are carried through. The resulting tensor has shape ``(batch, num_filters, height, width)``, reflecting the fact
that information from each channel has been mixed together.
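The six-dimensional ``x`` above is most naturally obtained by expanding an image batch into sliding patches. A minimal sketch of that preprocessing, using NumPy's ``sliding_window_view`` (the concrete sizes are assumptions for illustration):

.. code-block:: python

    import numpy as np

    batch, channel, height, width = 2, 3, 8, 8
    num_filters, kernel_h, kernel_w = 4, 3, 3

    images = np.random.normal(size=(batch, channel, height, width))
    w_val = np.random.normal(size=(num_filters, channel, kernel_h, kernel_w))

    # shape: (batch, channel, height - kernel_h + 1, width - kernel_w + 1, kernel_h, kernel_w)
    patches = np.lib.stride_tricks.sliding_window_view(images, (kernel_h, kernel_w), axis=(2, 3))

    y_val = y.eval({x: patches.astype(x.dtype), w: w_val.astype(w.dtype)})
    print(y_val.shape)  # (batch, num_filters, height - kernel_h + 1, width - kernel_w + 1)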
"""

# TODO: Is this doing something clever about unknown shapes?
# contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
# using einsum_call=True here is an internal api for opt_einsum... sorry
@@ -223,21 +363,24 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
shapes = [operand.type.shape for operand in operands]

if None in itertools.chain.from_iterable(shapes):
# We mark optimize = False, even in cases where there is no ordering optimization to be done
# because the inner graph may have to accommodate dynamic shapes.
# If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
# Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize
# the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right,
# pushing intermediate results to the end of the stack.
# We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will
# match more often

# If shapes become known later we will likely want to rebuild the Op (unless we inline it)
if len(operands) == 1:
path = [(0,)]
else:
# Create default path of repeating (1,0) that executes left to right cyclically
# with intermediate outputs being pushed to the end of the stack
# We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
path = [(1, 0) for i in range(len(operands) - 1)]
contraction_list = contraction_list_from_path(subscripts, operands, path)
optimize = (
len(operands) <= 2
) # If there are only 1 or 2 operands, there is no optimization to be done?

# If there are only 1 or 2 operands, there is no optimization to be done?
optimize = len(operands) <= 2
else:
# Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
# contraction order.
_, contraction_list = contract_path(
subscripts,
*shapes,
@@ -252,6 +395,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
def sum_uniques(
operand: TensorVariable, names: str, uniques: list[str]
) -> tuple[TensorVariable, str]:
"""Reduce unique indices (those that appear only once) in a given contraction step via summing."""
if uniques:
axes = [names.index(name) for name in uniques]
operand = operand.sum(axes)
Expand All @@ -264,6 +408,8 @@ def sum_repeats(
counts: collections.Counter,
keep_names: str,
) -> tuple[TensorVariable, str]:
"""Reduce repeated indices in a given contraction step via summation against an identity matrix."""

for name, count in counts.items():
if count > 1:
axes = [i for i, n in enumerate(names) if n == name]
