Skip to content

Commit

Permalink
Adding tensor.ttm. Adding use case in tenmat to support ttm testing. (#…
Browse files Browse the repository at this point in the history
…40)

Closes #27
  • Loading branch information
dmdunla authored Jul 16, 2022
1 parent cd32cb7 commit 3b3abcc
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pyttb/tenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, *args):
self.data = np.array([])

@classmethod
def from_data(cls, data, rdims, cdims, tshape=None):
def from_data(cls, data, rdims, cdims=None, tshape=None):
# CONVERT A MULTIDIMENSIONAL ARRAY

# Verify that data is a numeric numpy.ndarray
Expand All @@ -47,6 +47,10 @@ def from_data(cls, data, rdims, cdims, tshape=None):
# make data a 2d array with shape (1, data.shape[0]), i.e., a row vector
data = np.reshape(data.copy(), (1, data.shape[0]), order='F')

# data is ndarray and only rdims is specified
if cdims is None:
return ttb.tenmat.from_tensor_type(ttb.tensor.from_data(data), rdims)

# use data.shape for tshape if not provided
if tshape is None:
tshape = data.shape
Expand Down
52 changes: 52 additions & 0 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,58 @@ def symmetrize(self, grps=None, version=None):

return Y

def ttm(self, matrix, dims=None, transpose=False):
"""
Tensor times matrix
Parameters
----------
matrix: :class:`Numpy.ndarray`, list[:class:`Numpy.ndarray`]
dims: :class:`Numpy.ndarray`, int
transpose: boolean
"""

if dims is None:
dims = np.arange(self.ndims)
elif isinstance(dims, list):
dims = np.array(dims)
elif np.isscalar(dims) or isinstance(dims, list):
dims = np.array([dims])

if isinstance(matrix, list):
# Check that the dimensions are valid
dims, vidx = ttb.tt_dimscheck(dims, self.ndims, len(matrix))
# Calculate individual products
Y = self.ttm(matrix[vidx[0]], dims[0], transpose)
for k in range(1,dims.size):
Y = Y.ttm(matrix[vidx[k]], dims[k], transpose)
return Y

if not isinstance(matrix, np.ndarray):
assert False, "matrix must be of type numpy.ndarray"

if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
assert False, "dims must contain values in [0,self.dims]"

# old version (ver=0)
shape = np.array(self.shape)
n = dims[0]
order = np.array([n] + list(range(0,n)) + list(range(n+1,self.ndims)))
newdata = self.permute(order)
ids = np.array(list(range(0,n)) + list(range(n+1,self.ndims)))
newdata = np.reshape(newdata.data, (shape[n],np.prod(shape[ids])), order="F")
if transpose:
newdata = matrix.T @ newdata
p = matrix.shape[1]
else:
newdata = matrix @ newdata
p = matrix.shape[0]

newshape = np.array([p] + list(shape[range(0,n)]) + list(shape[range(n+1,self.ndims)]))
Y = np.reshape(newdata, newshape, order="F")
Y = np.transpose(Y, np.argsort(order))
return ttb.tensor.from_data(Y)

def ttv(self, vector, dims=None):
"""
Tensor times vector
Expand Down
4 changes: 4 additions & 0 deletions tests/test_tenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def test_tenmat_initialization_from_data(sample_ndarray_1way, sample_ndarray_2wa
assert tenmatNdarray4.shape == tenmatInstance.shape
assert tenmatNdarray4.tshape == tenmatInstance.tshape

## Constructor from 4d array just specifying rdims
tenmatNdarray4 = ttb.tenmat.from_data(ndarrayInstance4, np.array([0]))
assert (tenmatNdarray4.data == np.reshape(ndarrayInstance4, tenmatNdarray4.shape, order='F')).all()

# Exceptions

## data is not numpy.ndarray
Expand Down
80 changes: 80 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,86 @@ def test_tensor_squeeze(sample_tensor_2way):
# A singleton dimension
assert (ttb.tensor.from_data(np.array([[1, 2, 3]])).squeeze().data == np.array([1, 2, 3])).all()

@pytest.mark.indevelopment
def test_tensor_ttm(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
(params2, tensorInstance2) = sample_tensor_2way
(params3, tensorInstance3) = sample_tensor_3way
(params4, tensorInstance4) = sample_tensor_4way

M2 = np.reshape(np.arange(1,2*2+1),[2, 2], order='F')
M3 = np.reshape(np.arange(1,3*3+1),[3, 3], order='F')

# 3-way single matrix
T3 = tensorInstance3.ttm(M2, 0)
assert isinstance(T3, ttb.tensor)
assert T3.shape == (2,3,2)
data3 = np.array([[[ 7, 31],
[15, 39],
[23, 47]],
[[10, 46],
[22, 58],
[34, 70]]])
assert (T3.data == data3).all()

# 3-way single matrix, transposed
T3 = tensorInstance3.ttm(M2, 0, transpose=True)
assert isinstance(T3, ttb.tensor)
assert T3.shape == (2,3,2)
data3 = np.array([[[ 5, 23],
[11, 29],
[17, 35]],
[[11, 53],
[25, 67],
[39, 81]]])
assert (T3.data == data3).all()

# 3-way, two matrices, negative dimension
T3 = tensorInstance3.ttm([M2, M2], -2)
assert isinstance(T3, ttb.tensor)
assert T3.shape == (2,3,2)
data3 = np.array([[[100, 138],
[132, 186],
[164, 234]],
[[148, 204],
[196, 276],
[244, 348]]])
assert (T3.data == data3).all()

# 3-way, two matrices, explicit dimensions
T3 = tensorInstance3.ttm([M2, M3], [2,1])
assert isinstance(T3, ttb.tensor)
assert T3.shape == (2,3,2)
data3 = np.array([[[408, 576],
[498, 702],
[588, 828]],
[[456, 648],
[558, 792],
[660, 936]]])
assert (T3.data == data3).all()

# 3-way, 3 matrices, no dimensions specified
T3 = tensorInstance3.ttm([M2, M3, M2])
assert isinstance(T3, ttb.tensor)
assert T3.shape == (2,3,2)
data3 = np.array([[[1776, 2520],
[2172, 3078],
[2568, 3636]],
[[2640, 3744],
[3228, 4572],
[3816, 5400]]])
assert (T3.data == data3).all()

# 3-way, matrix must be np.ndarray
Tmat = ttb.tenmat.from_data(M2, rdims=np.array([0]))
with pytest.raises(AssertionError) as excinfo:
tensorInstance3.ttm(Tmat,0)
assert "matrix must be of type numpy.ndarray" in str(excinfo)

# 3-way, dims must be in range [0,self.ndims]
with pytest.raises(AssertionError) as excinfo:
tensorInstance3.ttm(M2, tensorInstance3.ndims + 1)
assert "dims must contain values in [0,self.dims]" in str(excinfo)

@pytest.mark.indevelopment
def test_tensor_ttv(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
(params2, tensorInstance2) = sample_tensor_2way
Expand Down

0 comments on commit 3b3abcc

Please sign in to comment.