diff --git a/pyttb/tensor.py b/pyttb/tensor.py index e70dcd19..340786c8 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -942,6 +942,48 @@ def ttm(self, matrix, dims=None, transpose=False): Y = np.transpose(Y, np.argsort(order)) return ttb.tensor.from_data(Y) + def ttt(self, other, selfdims=None, otherdims=None): + """ + Tensor mulitplication (tensor times tensor) + + Parameters + ---------- + other: :class:`ttb.tensor` + selfdims: :class:`Numpy.ndarray`, int + otherdims: :class:`Numpy.ndarray`, int + """ + + if not isinstance(other, tensor): + assert False, "other must be of type tensor" + + if selfdims is None: + selfdims = np.array([]) + selfshape = () + else: + selfshape = tuple(np.array(self.shape)[selfdims]) + + if otherdims is None: + otherdims = selfdims.copy() + othershape = () + else: + othershape = tuple(np.array(other.shape)[otherdims]) + + if not selfshape == othershape: + assert False, "Specified dimensions do not match" + + # Compute the product + + # Avoid transpose by reshaping self and computing result = self * other + amatrix = ttb.tenmat.from_tensor_type(self, cdims=selfdims) + bmatrix = ttb.tenmat.from_tensor_type(other, rdims=otherdims) + cmatrix = amatrix * bmatrix + + # Check whether or not the result is a scalar + if isinstance(cmatrix, ttb.tenmat): + return ttb.tensor.from_tensor_type(cmatrix) + else: + return cmatrix + def ttv(self, vector, dims=None): """ Tensor times vector diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 7497f352..c0c16f2d 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -1057,6 +1057,24 @@ def test_tensor_ttm(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way): tensorInstance3.ttm(M2, tensorInstance3.ndims + 1) assert "dims must contain values in [0,self.dims]" in str(excinfo) +@pytest.mark.indevelopment +def test_tensor_ttt(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way): + + M31 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[4,3,2], order='F')) + M32 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[3,4,2], order='F')) + + # outer product of M31 and M32 + TTT1 = M31.ttt(M32) + assert TTT1.shape == (4,3,2,3,4,2) + # choose two random 2-way slices + data11 = np.array([1,2,3,4]) + data12 = np.array([289,306,323,340]) + data13 = np.array([504,528,552,576]) + assert (TTT1[:,0,0,0,0,0].data == data11).all() + assert (TTT1[:,1,1,1,1,1].data == data12).all() + assert (TTT1[:,2,1,2,3,1].data == data13).all() + + @pytest.mark.indevelopment def test_tensor_ttv(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way): (params2, tensorInstance2) = sample_tensor_2way