Skip to content

Commit

Permalink
Make repeated indexes work
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 7, 2024
1 parent 3e958ce commit 180ef9d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pytensor/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.tensor.basic import (
arange,
expand_dims,
get_vector_length,
stack,
transpose,
Expand Down Expand Up @@ -36,9 +35,10 @@ def __init__(


def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
axis = normalize_axis_index(axis, get_vector_length(shape))
len_shape = get_vector_length(shape)
axis = normalize_axis_index(axis, len_shape)
values = arange(shape[axis])
return broadcast_to(shape_padright(values, axis), shape)
return broadcast_to(shape_padright(values, len_shape - axis - 1), shape)


def _delta(shape, axes: Sequence[int]) -> TensorVariable:
Expand All @@ -47,7 +47,7 @@ def _delta(shape, axes: Sequence[int]) -> TensorVariable:
iotas = [_iota(base_shape, i) for i in range(len(axes))]
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
result = reduce(and_, eyes)
return broadcast_to(expand_dims(result, tuple(axes)), shape)
return broadcast_to(result, shape)


def _removechars(s, chars):
Expand Down
33 changes: 33 additions & 0 deletions tests/tensor/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,39 @@
import pytest

import pytensor.tensor as pt
from pytensor import Mode
from pytensor.tensor.einsum import _delta, _iota


def test_iota():
mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
_iota((4, 8), 0).eval(mode=mode),
[
[0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3, 3],
],
)

np.testing.assert_allclose(
_iota((4, 8), 1).eval(mode=mode),
[
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
],
)


def test_delta():
mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
_delta((2, 2), (0, 1)).eval(mode=mode),
[[1.0, 0.0], [0.0, 1.0]],
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 180ef9d

Please sign in to comment.