Skip to content

Commit

Permalink
Fix einsum doctest
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 4, 2024
1 parent 2413d99 commit 70cc6a3
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions pytensor/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,30 @@ def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
.. testcode::
import pytensor as pt
shape = pt.as_tensor('shape', (5,))
print(pt._iota(shape, 0).eval())
import pytensor.tensor as pt
from pytensor.tensor.einsum import _iota
shape = pt.as_tensor((5,))
print(_iota(shape, 0).eval())
.. testoutput::
[0., 1., 2., 3., 4.]
[0 1 2 3 4]
In higher dimensions, it will look like many concatenated `pt.arange`:
In higher dimensions, it will look like many concatenated `arange`:
.. testcode::
shape = pt.as_tensor('shape', (5, 5))
print(pt._iota(shape, 1).eval())
shape = pt.as_tensor((5, 5))
print(_iota(shape, 1).eval())
.. testoutput::
[[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]
[[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]]
Setting ``axis=0`` above would result in the transpose of the output.
"""
Expand Down Expand Up @@ -218,14 +220,14 @@ def _general_dot(
from pytensor.tensor.einsum import _general_dot
import numpy as np
A = pt.tensor(shape = (3, 4, 5))
B = pt.tensor(shape = (3, 5, 2))
A = pt.tensor(shape=(3, 4, 5))
B = pt.tensor(shape=(3, 5, 2))
result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
A_val = np.empty((3, 4, 5))
B_val = np.empty((3, 5, 2))
print(result.shape.eval({A:A_val, B:B_val}))
print(tuple(result.shape.eval({A:A_val, B:B_val})))
.. testoutput::
Expand Down

0 comments on commit 70cc6a3

Please sign in to comment.