Skip to content

Commit

Permalink
Adding tensor.ttt implementation. (#44)
Browse files Browse the repository at this point in the history
Closes 28
  • Loading branch information
dmdunla committed Jul 18, 2022
1 parent 2ab1934 commit 16e57a7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
42 changes: 42 additions & 0 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,48 @@ def ttm(self, matrix, dims=None, transpose=False):
Y = np.transpose(Y, np.argsort(order))
return ttb.tensor.from_data(Y)

def ttt(self, other, selfdims=None, otherdims=None):
"""
Tensor mulitplication (tensor times tensor)
Parameters
----------
other: :class:`ttb.tensor`
selfdims: :class:`Numpy.ndarray`, int
otherdims: :class:`Numpy.ndarray`, int
"""

if not isinstance(other, tensor):
assert False, "other must be of type tensor"

if selfdims is None:
selfdims = np.array([])
selfshape = ()
else:
selfshape = tuple(np.array(self.shape)[selfdims])

if otherdims is None:
otherdims = selfdims.copy()
othershape = ()
else:
othershape = tuple(np.array(other.shape)[otherdims])

if not selfshape == othershape:
assert False, "Specified dimensions do not match"

# Compute the product

# Avoid transpose by reshaping self and computing result = self * other
amatrix = ttb.tenmat.from_tensor_type(self, cdims=selfdims)
bmatrix = ttb.tenmat.from_tensor_type(other, rdims=otherdims)
cmatrix = amatrix * bmatrix

# Check whether or not the result is a scalar
if isinstance(cmatrix, ttb.tenmat):
return ttb.tensor.from_tensor_type(cmatrix)
else:
return cmatrix

def ttv(self, vector, dims=None):
"""
Tensor times vector
Expand Down
18 changes: 18 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,24 @@ def test_tensor_ttm(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
tensorInstance3.ttm(M2, tensorInstance3.ndims + 1)
assert "dims must contain values in [0,self.dims]" in str(excinfo)

@pytest.mark.indevelopment
def test_tensor_ttt(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):

M31 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[4,3,2], order='F'))
M32 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[3,4,2], order='F'))

# outer product of M31 and M32
TTT1 = M31.ttt(M32)
assert TTT1.shape == (4,3,2,3,4,2)
# choose two random 2-way slices
data11 = np.array([1,2,3,4])
data12 = np.array([289,306,323,340])
data13 = np.array([504,528,552,576])
assert (TTT1[:,0,0,0,0,0].data == data11).all()
assert (TTT1[:,1,1,1,1,1].data == data12).all()
assert (TTT1[:,2,1,2,3,1].data == data13).all()


@pytest.mark.indevelopment
def test_tensor_ttv(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
(params2, tensorInstance2) = sample_tensor_2way
Expand Down

0 comments on commit 16e57a7

Please sign in to comment.