diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py
index 5082d7a45e..d7966269c0 100644
--- a/pytensor/tensor/einsum.py
+++ b/pytensor/tensor/einsum.py
@@ -8,7 +8,6 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor.basic import (
     arange,
-    expand_dims,
     get_vector_length,
     stack,
     transpose,
@@ -36,9 +35,10 @@ def __init__(
 
 
 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
-    axis = normalize_axis_index(axis, get_vector_length(shape))
+    len_shape = get_vector_length(shape)
+    axis = normalize_axis_index(axis, len_shape)
     values = arange(shape[axis])
-    return broadcast_to(shape_padright(values, axis), shape)
+    return broadcast_to(shape_padright(values, len_shape - axis - 1), shape)
 
 
 def _delta(shape, axes: Sequence[int]) -> TensorVariable:
@@ -47,7 +47,7 @@ def _delta(shape, axes: Sequence[int]) -> TensorVariable:
     iotas = [_iota(base_shape, i) for i in range(len(axes))]
     eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
     result = reduce(and_, eyes)
-    return broadcast_to(expand_dims(result, tuple(axes)), shape)
+    return broadcast_to(result, shape)
 
 
 def _removechars(s, chars):
diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py
index b813f2b3b3..95ea21c11b 100644
--- a/tests/tensor/test_einsum.py
+++ b/tests/tensor/test_einsum.py
@@ -4,6 +4,39 @@
 import pytest
 
 import pytensor.tensor as pt
+from pytensor import Mode
+from pytensor.tensor.einsum import _delta, _iota
+
+
+def test_iota():
+    mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        _iota((4, 8), 0).eval(mode=mode),
+        [
+            [0, 0, 0, 0, 0, 0, 0, 0],
+            [1, 1, 1, 1, 1, 1, 1, 1],
+            [2, 2, 2, 2, 2, 2, 2, 2],
+            [3, 3, 3, 3, 3, 3, 3, 3],
+        ],
+    )
+
+    np.testing.assert_allclose(
+        _iota((4, 8), 1).eval(mode=mode),
+        [
+            [0, 1, 2, 3, 4, 5, 6, 7],
+            [0, 1, 2, 3, 4, 5, 6, 7],
+            [0, 1, 2, 3, 4, 5, 6, 7],
+            [0, 1, 2, 3, 4, 5, 6, 7],
+        ],
+    )
+
+
+def test_delta():
+    mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        _delta((2, 2), (0, 1)).eval(mode=mode),
+        [[1.0, 0.0], [0.0, 1.0]],
+    )
 
 
 @pytest.mark.parametrize(