Implement Einsum
Co-authored-by: Adrian Seyboldt <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Rob Zinkov <[email protected]>
4 people authored and ricardoV94 committed Jul 19, 2024
1 parent 981688c commit 52575e8
Showing 18 changed files with 1,361 additions and 144 deletions.
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
@@ -4,6 +4,7 @@
 # Load dispatch specializations
 import pytensor.link.jax.dispatch.blas
 import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.einsum
 import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.extra_ops
 import pytensor.link.jax.dispatch.pad
20 changes: 20 additions & 0 deletions pytensor/link/jax/dispatch/einsum.py
@@ -0,0 +1,20 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.einsum import Einsum
+
+
+@jax_funcify.register(Einsum)
+def jax_funcify_Einsum(op, **kwargs):
+    """Dispatch einsum to JAX.
+
+    This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level.
+    This happens when some of the dimension lengths are unknown. This is never a problem in JAX,
+    as it always compiles a function per runtime input shape.
+    """
+    subscripts = op.subscripts
+
+    def einsum(*operands):
+        return jnp.einsum(subscripts, *operands, optimize="optimal")
+
+    return einsum
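For context, a minimal sketch of how this dispatch path gets exercised; the variable names and shapes below are illustrative, and `mode="JAX"` is the usual way to select PyTensor's JAX backend:

```python
import pytensor
import pytensor.tensor as pt

# With unknown (None) dims, the Einsum op cannot be rewritten away at the
# PyTensor level, so it survives to the backend, where jax_funcify_Einsum
# maps it onto jnp.einsum.
x = pt.tensor("x", shape=(None, None))
y = pt.tensor("y", shape=(None, None))
out = pt.einsum("ij,jk->ik", x, y)

f = pytensor.function([x, y], out, mode="JAX")
```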
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
@@ -151,6 +151,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:


 # isort: off
+from pytensor.tensor.einsum import einsum
 from pytensor.tensor.functional import vectorize
 # isort: on

32 changes: 20 additions & 12 deletions pytensor/tensor/basic.py
@@ -1700,21 +1700,22 @@ def do_constant_folding(self, fgraph, node):
         return False
 
         for client, idx in clients:
-            if isinstance(client.op, Output):
+            client_op = client.op
+            if isinstance(client_op, Output):
                 # If the output is a constant, it will have to be deepcopied
                 # each time the function is called. So we do not fold.
                 return False
-            # Allow alloc to be lifted out of Elemwise before constant folding it
-            elif isinstance(client.op, Elemwise):
-                return None
+            # Op's through which Alloc can be lifted
+            elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
+                return False
             # Same for Blockwise, unless it has no batch_dims
-            elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
-                return None
+            elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
+                return False
             elif (
                 # The following ops work inplace of their input id 0.
                 idx == 0
                 and isinstance(
-                    client.op,
+                    client_op,
                     pytensor.tensor.subtensor.IncSubtensor
                     | pytensor.tensor.subtensor.AdvancedIncSubtensor1
                     | pytensor.tensor.subtensor.AdvancedIncSubtensor
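Roughly, the guard above keeps a constant `Alloc` un-folded whenever a client could lift it instead. A small illustrative graph (shapes are hypothetical; the lifting itself happens in later rewrite passes):

```python
import pytensor.tensor as pt

x = pt.vector("x")
# Folding alloc(0.5, 1000) would materialize a 1000-element constant array.
# Leaving the Alloc intact lets rewrites lift it above the Elemwise add,
# so the broadcast never has to be materialized as a constant.
out = pt.alloc(0.5, 1000) + x
```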
@@ -2035,7 +2036,12 @@ def transpose(x, axes=None):
     _x = as_tensor_variable(x)
 
     if axes is None:
-        axes = list(range((_x.type.ndim - 1), -1, -1))
+        axes = tuple(range((_x.type.ndim - 1), -1, -1))
+
+    if tuple(axes) == tuple(range(len(axes))):
+        # No-op
+        return _x
 
     ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
 
     if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
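A quick sanity check of the new short-circuit; a sketch that relies on `as_tensor_variable` returning an existing variable unchanged:

```python
import pytensor.tensor as pt

x = pt.tensor3("x")
assert pt.transpose(x, (0, 1, 2)) is x  # identity permutation: no DimShuffle built
assert pt.transpose(x) is not x         # full reversal still creates a new variable
```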
@@ -3950,6 +3956,10 @@ def moveaxis(
     source = normalize_axis_tuple(source, a.ndim, "source")
     destination = normalize_axis_tuple(destination, a.ndim, "destination")
 
+    if source == destination:
+        # It's a no-op
+        return a
+
     if len(source) != len(destination):
         raise ValueError(
             "`source` and `destination` arguments must have the same number of elements"
@@ -4260,9 +4270,7 @@ def atleast_Nd(
 atleast_3d = partial(atleast_Nd, n=3)
 
 
-def expand_dims(
-    a: np.ndarray | TensorVariable, axis: tuple[int, ...]
-) -> TensorVariable:
+def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
     """Expand the shape of an array.
 
     Insert a new axis that will appear at the `axis` position in the expanded
@@ -4281,7 +4289,7 @@ def expand_dims(
"""
a = as_tensor(a)

if not isinstance(axis, tuple | list):
if not isinstance(axis, Sequence):
axis = (axis,)

out_ndim = len(axis) + a.ndim
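The widened annotation means any `Sequence[int]` now passes through without being re-wrapped; a small sketch:

```python
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.expand_dims(x, range(2))  # a range is a Sequence, so it is used as-is
assert y.ndim == 4

z = pt.expand_dims(x, 0)         # a bare int is still wrapped into (0,)
assert z.ndim == 3
```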