From 70cc6a3213ea60ce704b458bd32a2940ae64ecef Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 4 Aug 2024 11:55:13 +0200 Subject: [PATCH] Fix einsum doctest --- pytensor/tensor/einsum.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index ab0c399f8c..dab2520988 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -86,28 +86,30 @@ def _iota(shape: TensorVariable, axis: int) -> TensorVariable: .. testcode:: - import pytensor as pt - shape = pt.as_tensor('shape', (5,)) - print(pt._iota(shape, 0).eval()) + import pytensor.tensor as pt + from pytensor.tensor.einsum import _iota + + shape = pt.as_tensor((5,)) + print(_iota(shape, 0).eval()) .. testoutput:: - [0., 1., 2., 3., 4.] + [0 1 2 3 4] - In higher dimensions, it will look like many concatenated `pt.arange`: + In higher dimensions, it will look like many concatenated `arange`: .. testcode:: - shape = pt.as_tensor('shape', (5, 5)) - print(pt._iota(shape, 1).eval()) + shape = pt.as_tensor((5, 5)) + print(_iota(shape, 1).eval()) .. testoutput:: - [[0., 1., 2., 3., 4.], - [0., 1., 2., 3., 4.], - [0., 1., 2., 3., 4.], - [0., 1., 2., 3., 4.], - [0., 1., 2., 3., 4.]] + [[0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4]] Setting ``axis=0`` above would result in the transpose of the output. """ @@ -218,14 +220,14 @@ def _general_dot( from pytensor.tensor.einsum import _general_dot import numpy as np - A = pt.tensor(shape = (3, 4, 5)) - B = pt.tensor(shape = (3, 5, 2)) + A = pt.tensor(shape=(3, 4, 5)) + B = pt.tensor(shape=(3, 5, 2)) result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]]) A_val = np.empty((3, 4, 5)) B_val = np.empty((3, 5, 2)) - print(result.shape.eval({A:A_val, B:B_val})) + print(tuple(result.shape.eval({A:A_val, B:B_val}))) .. testoutput::