Skip to content

Commit

Permalink
Allowing rdims or cdims to be empty array. (#43)
Browse files Browse the repository at this point in the history
Closes #42
  • Loading branch information
dmdunla committed Jul 18, 2022
1 parent e296f3a commit 2ab1934
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
13 changes: 10 additions & 3 deletions pyttb/tenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,19 @@ def from_tensor_type(cls, source, rdims=None, cdims=None, cdims_cyclic=None):
elif rdims is None and cdims is not None:
rdims = np.setdiff1d(alldims, cdims)


dims = np.hstack([rdims, cdims])
# if rdims or cdims is empty, hstack will output an array of float not int
if rdims.size == 0:
dims = cdims.copy()
elif cdims.size == 0:
dims = rdims.copy()
else:
dims = np.hstack([rdims, cdims])
if not len(dims) == n or not (alldims == np.sort(dims)).all():
assert False, 'Incorrect specification of dimensions, the sorted concatenation of rdims and cdims must be range(source.ndims).'

data = np.reshape(source.permute(dims).data, (np.prod(np.array(tshape)[rdims]), np.prod(np.array(tshape)[cdims])), order='F')
rprod = 1 if rdims.size == 0 else np.prod(np.array(tshape)[rdims])
cprod = 1 if cdims.size == 0 else np.prod(np.array(tshape)[cdims])
data = np.reshape(source.permute(dims).data, (rprod, cprod), order='F')

# Create tenmat
tenmatInstance = cls()
Expand Down
21 changes: 20 additions & 1 deletion tests/test_tenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ def sample_ndarray_2way():
params = {'data':ndarrayInstance, 'shape':shape}
return params, ndarrayInstance

@pytest.fixture()
def sample_tensor_3way():
data = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
shape = (2, 3, 2)
params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape}
tensorInstance = ttb.tensor().from_data(data, shape)
return params, tensorInstance

@pytest.fixture()
def sample_ndarray_4way():
shape = (2, 2, 2, 2)
Expand Down Expand Up @@ -184,8 +192,9 @@ def test_tenmat_initialization_from_data(sample_ndarray_1way, sample_ndarray_2wa
assert exc in str(excinfo)

@pytest.mark.indevelopment
def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tensor_4way):
def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tensor_3way, sample_tensor_4way):
(_, tensorInstance) = sample_tensor_4way
(_, tensorInstance3) = sample_tensor_3way
(params, tenmatInstance) = sample_tenmat_4way
tshape = params['tshape']
rdims = params['rdims']
Expand All @@ -208,6 +217,11 @@ def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tenso
assert tenmatInstance.shape == tenmatTensorRdims.shape
assert tenmatInstance.tshape == tenmatTensorRdims.tshape

# Constructor from tensor using empty rdims
tenmatTensorRdims = ttb.tenmat.from_tensor_type(tensorInstance3, rdims=np.array([]))
data = np.reshape(np.arange(1,13),(1,12))
assert (tenmatTensorRdims.data == data).all()

# Constructor from tensor using cdims only
tenmatTensorCdims = ttb.tenmat.from_tensor_type(tensorInstance, cdims=cdims)
assert (tenmatInstance.data == tenmatTensorCdims.data).all()
Expand All @@ -216,6 +230,11 @@ def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tenso
assert tenmatInstance.shape == tenmatTensorCdims.shape
assert tenmatInstance.tshape == tenmatTensorCdims.tshape

# Constructor from tensor using empty cdims
tenmatTensorCdims = ttb.tenmat.from_tensor_type(tensorInstance3, cdims=np.array([]))
data = np.reshape(np.arange(1,13),(12,1))
assert (tenmatTensorCdims.data == data).all()

# Constructor from tensor using rdims and cdims
tenmatTensorRdimsCdims = ttb.tenmat.from_tensor_type(tensorInstance, rdims=rdims, cdims=cdims)
assert (tenmatInstance.data == tenmatTensorRdimsCdims.data).all()
Expand Down

0 comments on commit 2ab1934

Please sign in to comment.