Skip to content

Commit

Permalink
Fixing np.reshape usage. Adding more tests for tensor.ttv. (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmdunla committed Jul 13, 2022
1 parent 51cf7f6 commit 0289d97
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
12 changes: 3 additions & 9 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,6 @@ def permute(self, order):
# Np transpose does error checking on order, acts as permutation
return ttb.tensor.from_data(np.transpose(self.data, order))

# TODO should this be a property?
def reshape(self, *shape):
"""
Reshapes a tensor
Expand Down Expand Up @@ -913,6 +912,8 @@ def ttv(self, vector, dims=None):
dims = np.array([dims])

# Check that vector is a list of vectors, if not place single vector as element in list
if isinstance(vector, list):
return self.ttv(np.array(vector), dims)
if len(vector.shape) == 1 and isinstance(vector[0], (int, float, np.int_, np.float_)):
return self.ttv(np.array([vector]), dims)

Expand All @@ -924,13 +925,6 @@ def ttv(self, vector, dims=None):
if vector[vidx[i]].shape != (self.shape[dims[i]], ):
assert False, "Multiplicand is wrong size"

# TODO: not sure what this special case handles
#if exist('tensor/ttv_single', 'file') == 3:
# c = a
# for i = numel(dims): -1: 1
# c = ttv_single(c, v{vidx(i)}, dims(i))
# return c

# Extract the data
c = self.data.copy()

Expand All @@ -944,7 +938,7 @@ def ttv(self, vector, dims=None):
sz = np.array(self.shape)[np.concatenate((remdims, dims))]

for i in range(dims.size-1, -1, -1):
c = np.reshape(c, tuple([np.prod(sz[0:n-1]), sz[n-1]]))
c = np.reshape(c, tuple([np.prod(sz[0:n-1]), sz[n-1]]), order='F')
c = c.dot(vector[vidx[i]])
n -= 1
# If needed, convert the final result back to tensor
Expand Down
48 changes: 41 additions & 7 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,20 +978,54 @@ def test_tensor_squeeze(sample_tensor_2way):
assert (ttb.tensor.from_data(np.array([[1, 2, 3]])).squeeze().data == np.array([1, 2, 3])).all()

@pytest.mark.indevelopment
def test_tensor_ttv(sample_tensor_2way):
(params, tensorInstance) = sample_tensor_2way
def test_tensor_ttv(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
(params2, tensorInstance2) = sample_tensor_2way
(params3, tensorInstance3) = sample_tensor_3way
(params3, tensorInstance4) = sample_tensor_4way

# Wrong shape vector
with pytest.raises(AssertionError) as excinfo:
tensorInstance.ttv(np.array([np.array([1, 2]), np.array([1, 2])]))
tensorInstance2.ttv(np.array([np.array([1, 2]), np.array([1, 2])]))
assert "Multiplicand is wrong size" in str(excinfo)

# Multiply by single vector
assert (tensorInstance.ttv(np.array([2, 2]), 0).data == np.array([2, 2]).dot(params['data'])).all()
# 2-way Multiply by single vector
T2 = tensorInstance2.ttv(np.array([2, 2]), 0)
assert isinstance(T2, ttb.tensor)
assert T2.shape == (3,)
assert (T2.data == np.array([10,14,18])).all()

# Multiply by multiple vectors, infer dimensions
assert tensorInstance2.ttv(np.array([np.array([2, 2]), np.array([1, 1, 1])])) == 42

# Multiply by multiple vectors as list of numpy.ndarrays
assert tensorInstance2.ttv([np.array([2, 2]), np.array([1, 1, 1])]) == 42

# 3-way Multiply by single vector
T3 = tensorInstance3.ttv(2 * np.ones((tensorInstance3.shape[0],)), 0)
assert isinstance(T3, ttb.tensor)
assert T3.shape == (tensorInstance3.shape[1], tensorInstance3.shape[2])
assert (T3.data == np.array([[6,30],[14,38],[22,46]])).all()

# Multiply by multiple vectors, infer dimensions
assert tensorInstance3.ttv(np.array([np.array([2, 2]), np.array([1, 1, 1]), np.array([2, 2])])) == 312

# 4-way Multiply by single vector
T4 = tensorInstance4.ttv(2 * np.ones((tensorInstance4.shape[0],)), 0)
assert isinstance(T4, ttb.tensor)
assert T4.shape == (tensorInstance4.shape[1], tensorInstance4.shape[2], tensorInstance4.shape[3])
data4 = np.array([[[ 12, 174, 336],
[ 66, 228, 390],
[120, 282, 444]],
[[ 30, 192, 354],
[ 84, 246, 408],
[138, 300, 462]],
[[ 48, 210, 372],
[102, 264, 426],
[156, 318, 480]]])
assert (T4.data == data4).all()

# Multiply by multiple vectors, infer dimensions
assert (tensorInstance.ttv(np.array([np.array([2, 2]), np.array([1, 1, 1])])) ==
np.array([1, 1, 1]).dot(np.array([2, 2]).dot(params['data'])))
assert tensorInstance4.ttv(np.array([np.array([1, 1, 1]), np.array([1, 1, 1]), np.array([1, 1, 1]), np.array([1, 1, 1])])) == 3321

@pytest.mark.indevelopment
def test_tensor_ttsv(sample_tensor_2way):
Expand Down

0 comments on commit 0289d97

Please sign in to comment.