diff --git a/pyttb/sptensor.py b/pyttb/sptensor.py index 04332155..f6d26fc4 100644 --- a/pyttb/sptensor.py +++ b/pyttb/sptensor.py @@ -1393,7 +1393,6 @@ def __getitem__(self, item): return a - # pylint:disable=too-many-statements, too-many-branches, too-many-locals def __setitem__(self, key, value): """ Subscripted assignment for sparse tensor. @@ -1443,294 +1442,303 @@ def __setitem__(self, key, value): (isinstance(value, np.ndarray) and value.size == 0) or (isinstance(value, list) and value == []) ): - return + return None # Determine if we are doing a substenor or list of subscripts objectType = tt_assignment_type(self, key, value) # Case 1: Replace a sub-tensor - if objectType == "subtensor": # pylint:disable=too-many-nested-blocks - # Case I(a): RHS is another sparse tensor - if isinstance(value, ttb.sptensor): - # First, Resize the tensor and check the size match with the tensor - # that's being inserted. - m = 0 - newsz = [] - for n, key_n in enumerate(key): - if isinstance(key_n, slice): - if self.ndims <= n: - if key_n.stop is None: - newsz.append(value.shape[m]) - else: - newsz.append(key_n.stop) - else: - if key_n.stop is None: - newsz.append(max([self.shape[n], value.shape[m]])) - else: - newsz.append(max([self.shape[n], key_n.stop])) - m = m + 1 - elif isinstance(key_n, (float, int)): - if self.ndims <= n: - newsz.append(key_n + 1) - else: - newsz.append(max([self.shape[n], key_n + 1])) - else: - if len(key_n) != value.shape[m]: - assert False, "RHS does not match range size" - if self.ndims <= n: - newsz.append(max(key_n) + 1) - else: - newsz.append(max([self.shape[n], max(key_n) + 1])) - self.shape = tuple(newsz) - - # Expand subs array if there are new modes, i.e., if the order - # has increased. - if self.subs.size > 0 and (len(self.shape) > self.subs.shape[1]): - self.subs = np.append( + if objectType == "subtensor": + return self._set_subtensor(key, value) + # Case 2: Subscripts + if objectType == "subscripts": + return self._set_subscripts(key, value) + raise ValueError("Unknown assignment type") # pragma: no cover + + def _set_subscripts(self, key, value): + # Case II: Replacing values at specific indices + newsubs = key + if len(newsubs.shape) == 1: + newsubs = np.expand_dims(newsubs, axis=0) + tt_subscheck(newsubs, nargout=False) + + # Error check on subscripts + if newsubs.shape[1] < self.ndims: + assert False, "Invalid subscripts" + + # Check for expanding the order + if newsubs.shape[1] > self.ndims: + newshape = list(self.shape) + # TODO no need for loop, just add correct size + for _ in range(self.ndims, newsubs.shape[1]): + newshape.append(1) + if self.subs.size > 0: + self.subs = np.concatenate( + ( self.subs, - np.zeros( - shape=( - self.subs.shape[0], - len(self.shape) - self.subs.shape[1], - ) + np.ones( + (self.shape[0], newsubs.shape[1] - self.ndims), + dtype=int, ), - axis=1, - ) - # Delete what currently occupies the specified range - rmloc = self.subdims(key) - kploc = np.setdiff1d(range(0, self.nnz), rmloc) - # TODO: evaluate solution for assigning value to empty sptensor - if len(self.subs.shape) > 1: - newsubs = self.subs[kploc.astype(int), :] - else: - newsubs = self.subs[kploc.astype(int)] - newvals = self.vals[kploc.astype(int)] - - # Renumber the subscripts - addsubs = ttb.tt_irenumber(value, self.shape, key) - if newsubs.size > 0 and addsubs.size > 0: - self.subs = np.vstack((newsubs, addsubs)) - self.vals = np.vstack((newvals, value.vals)) - elif newsubs.size > 0: - self.subs = newsubs - self.vals = newvals - elif addsubs.size > 0: - self.subs = addsubs - self.vals = value.vals - else: - self.subs = np.array([], ndmin=2, dtype=int) - self.vals = np.array([], ndmin=2) - - return - # Case I(b): Value is zero or scalar + ), + axis=1, + ) + self.shape = tuple(newshape) - # First, resize the tensor, determine new size of existing modes + # Copy rhs to newvals + newvals = value + + if isinstance(newvals, (float, int)): + newvals = np.expand_dims([newvals], axis=1) + + # Error check the rhs is a column vector. We don't bother to handle any + # other type with sparse tensors + tt_valscheck(newvals, nargout=False) + + # Determine number of nonzeros being inserted. + # (This is determined by number of subscripts) + newnnz = newsubs.shape[0] + + # Error check on size of newvals + if newvals.size == 1: + # Special case where newvals is a single element to be assigned + # to multiple LHS. Fix to correct size + newvals = newvals * np.ones((newnnz, 1)) + + elif newvals.shape[0] != newnnz: + # Sizes don't match + assert False, "Number of subscripts and number of values do not match!" + + # Remove duplicates and print warning if any duplicates were removed + newsubs, idx = np.unique(newsubs, axis=0, return_index=True) + if newsubs.shape[0] != newnnz: + warnings.warn("Duplicate assignments discarded") + + newvals = newvals[idx] + + # Find which subscripts already exist and their locations + tf = ttb.tt_ismember_rows(newsubs, self.subs) + loc = np.where(tf >= 0)[0].astype(int) + + # Split into three groups for processing: + # + # Group A: Elements that already exist and need to be changed + # Group B: Elements that already exist and need to be removed + # Group C: Elements that do not exist and need to be added + # + # Note that we are ignoring any new zero elements, because + # those obviously do not need to be added. Also, it's + # important to process Group A before Group B because the + # processing of Group B may change the locations of the + # remaining elements. + + # TF+1 for logical consideration because 0 is valid index + # and -1 is our null flag + idxa = np.logical_and(tf + 1, newvals)[0] + idxb = np.logical_and(tf + 1, np.logical_not(newvals))[0] + idxc = np.logical_and(np.logical_not(tf + 1), newvals)[0] + + # Process Group A: Changing values + if np.sum(idxa) > 0: + self.vals[tf[idxa]] = newvals[idxa] + # Proces Group B: Removing Values + if np.sum(idxb) > 0: + removesubs = loc[idxb] + keepsubs = np.setdiff1d(range(0, self.nnz), removesubs) + self.subs = self.subs[keepsubs, :] + self.vals = self.vals[keepsubs] + # Process Group C: Adding new, nonzero values + if np.sum(idxc) > 0: + if self.subs.size > 0: + self.subs = np.vstack((self.subs, newsubs[idxc, :])) + self.vals = np.vstack((self.vals, newvals[idxc])) + else: + self.subs = newsubs[idxc, :] + self.vals = newvals[idxc] + + # Resize the tensor + newshape = [] + for n, dim in enumerate(self.shape): + smax = max(newsubs[:, n] + 1) + newshape.append(max(dim, smax)) + self.shape = tuple(newshape) + + # pylint:disable=too-many-statements + def _set_subtensor(self, key, value): + # Case I(a): RHS is another sparse tensor + if isinstance(value, ttb.sptensor): + # First, Resize the tensor and check the size match with the tensor + # that's being inserted. + m = 0 newsz = [] - for n in range(0, self.ndims): - if isinstance(key[n], slice): - if key[n].stop is None: - newsz.append(self.shape[n]) + for n, key_n in enumerate(key): + if isinstance(key_n, slice): + if self.ndims <= n: + if key_n.stop is None: + newsz.append(value.shape[m]) + else: + newsz.append(key_n.stop) else: - newsz.append(max([self.shape[n], key[n].stop])) - else: - newsz.append(max([self.shape[n], key[n] + 1])) - - # Determine size of new modes, if any - for n in range(self.ndims, len(key)): - if isinstance(key[n], slice): - if key[n].stop is None: - assert False, ( - "Must have well defined slice when expanding sptensor " - "shape with setitem" - ) + if key_n.stop is None: + newsz.append(max([self.shape[n], value.shape[m]])) + else: + newsz.append(max([self.shape[n], key_n.stop])) + m = m + 1 + elif isinstance(key_n, (float, int)): + if self.ndims <= n: + newsz.append(key_n + 1) else: - newsz.append(key[n].stop) - elif isinstance(key[n], np.ndarray): - newsz.append(max(key[n]) + 1) + newsz.append(max([self.shape[n], key_n + 1])) else: - newsz.append(key[n] + 1) + if len(key_n) != value.shape[m]: + assert False, "RHS does not match range size" + if self.ndims <= n: + newsz.append(max(key_n) + 1) + else: + newsz.append(max([self.shape[n], max(key_n) + 1])) self.shape = tuple(newsz) - # Expand subs array if there are new modes, i.e. if the order has increased - if self.subs.size > 0 and len(self.shape) > self.subs.shape[1]: + # Expand subs array if there are new modes, i.e., if the order + # has increased. + if self.subs.size > 0 and (len(self.shape) > self.subs.shape[1]): self.subs = np.append( self.subs, np.zeros( - shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1]) + shape=( + self.subs.shape[0], + len(self.shape) - self.subs.shape[1], + ) ), axis=1, ) + # Delete what currently occupies the specified range + rmloc = self.subdims(key) + kploc = np.setdiff1d(range(0, self.nnz), rmloc) + # TODO: evaluate solution for assigning value to empty sptensor + if len(self.subs.shape) > 1: + newsubs = self.subs[kploc.astype(int), :] + else: + newsubs = self.subs[kploc.astype(int)] + newvals = self.vals[kploc.astype(int)] + + # Renumber the subscripts + addsubs = ttb.tt_irenumber(value, self.shape, key) + if newsubs.size > 0 and addsubs.size > 0: + self.subs = np.vstack((newsubs, addsubs)) + self.vals = np.vstack((newvals, value.vals)) + elif newsubs.size > 0: + self.subs = newsubs + self.vals = newvals + elif addsubs.size > 0: + self.subs = addsubs + self.vals = value.vals + else: + self.subs = np.array([], ndmin=2, dtype=int) + self.vals = np.array([], ndmin=2) - # Case I(b)i: Zero right-hand side - if isinstance(value, (int, float)) and value == 0: - # Delete what currently occupies the specified range - rmloc = self.subdims(key) - kploc = np.setdiff1d(range(0, self.nnz), rmloc).astype(int) - self.subs = self.subs[kploc, :] - self.vals = self.vals[kploc] - return - - # Case I(b)ii: Scalar Right Hand Side - if isinstance(value, (int, float)): - # Determine number of dimensions (may be larger than current number) - N = len(key) - keyCopy = np.array(key) - # Figure out how many indices are in each dimension - nssubs = np.zeros((N, 1)) - for n in range(0, N): - if isinstance(key[n], slice): - # Generate slice explicitly to determine its length - keyCopy[n] = np.arange(0, self.shape[n])[key[n]] - indicesInN = len(keyCopy[n]) - else: - indicesInN = 1 - nssubs[n] = indicesInN - - # Preallocate (discover any memory issues here!) - addsubs = np.zeros((np.prod(nssubs).astype(int), N)) + return + # Case I(b): Value is zero or scalar - # Generate appropriately sized ones vectors - o = [] - for n in range(N): - o.append(np.ones((int(nssubs[n]), 1))) + # First, resize the tensor, determine new size of existing modes + newsz = [] + for n in range(0, self.ndims): + if isinstance(key[n], slice): + if key[n].stop is None: + newsz.append(self.shape[n]) + else: + newsz.append(max([self.shape[n], key[n].stop])) + else: + newsz.append(max([self.shape[n], key[n] + 1])) + + # Determine size of new modes, if any + for n in range(self.ndims, len(key)): + if isinstance(key[n], slice): + if key[n].stop is None: + assert False, ( + "Must have well defined slice when expanding sptensor " + "shape with setitem" + ) + else: + newsz.append(key[n].stop) + elif isinstance(key[n], np.ndarray): + newsz.append(max(key[n]) + 1) + else: + newsz.append(key[n] + 1) + self.shape = tuple(newsz) - # Generate each column of the subscripts in turn - for n in range(N): - i = o.copy() - if not np.isscalar(keyCopy[n]): - i[n] = np.array(keyCopy[n])[:, None] - else: - i[n] = np.array(keyCopy[n], ndmin=2) - addsubs[:, n] = ttb.khatrirao(i).transpose()[:] + # Expand subs array if there are new modes, i.e. if the order has increased + if self.subs.size > 0 and len(self.shape) > self.subs.shape[1]: + self.subs = np.append( + self.subs, + np.zeros( + shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1]) + ), + axis=1, + ) - if self.subs.size > 0: - # Replace existing values - loc = ttb.tt_intersect_rows(self.subs, addsubs) - self.vals[loc] = value - # pare down list of subscripts to add - addsubs = addsubs[ttb.tt_setdiff_rows(addsubs, self.subs)] - - # If there are things to insert then insert them - if addsubs.size > 0: - if self.subs.size > 0: - self.subs = np.vstack((self.subs, addsubs.astype(int))) - self.vals = np.vstack( - (self.vals, value * np.ones((addsubs.shape[0], 1))) - ) - else: - self.subs = addsubs.astype(int) - self.vals = value * np.ones(addsubs.shape[0]) - return + # Case I(b)i: Zero right-hand side + if isinstance(value, (int, float)) and value == 0: + # Delete what currently occupies the specified range + rmloc = self.subdims(key) + kploc = np.setdiff1d(range(0, self.nnz), rmloc).astype(int) + self.subs = self.subs[kploc, :] + self.vals = self.vals[kploc] + return - assert False, "Invalid assignment value" + # Case I(b)ii: Scalar Right Hand Side + if isinstance(value, (int, float)): + # Determine number of dimensions (may be larger than current number) + N = len(key) + keyCopy = np.array(key) + # Figure out how many indices are in each dimension + nssubs = np.zeros((N, 1)) + for n in range(0, N): + if isinstance(key[n], slice): + # Generate slice explicitly to determine its length + keyCopy[n] = np.arange(0, self.shape[n])[key[n]] + indicesInN = len(keyCopy[n]) + else: + indicesInN = 1 + nssubs[n] = indicesInN + + # Preallocate (discover any memory issues here!) + addsubs = np.zeros((np.prod(nssubs).astype(int), N)) + + # Generate appropriately sized ones vectors + o = [] + for n in range(N): + o.append(np.ones((int(nssubs[n]), 1))) + + # Generate each column of the subscripts in turn + for n in range(N): + i = o.copy() + if not np.isscalar(keyCopy[n]): + i[n] = np.array(keyCopy[n])[:, None] + else: + i[n] = np.array(keyCopy[n], ndmin=2) + addsubs[:, n] = ttb.khatrirao(i).transpose()[:] - # Case 2: Subscripts - elif objectType == "subscripts": - # Case II: Replacing values at specific indices - - newsubs = key - if len(newsubs.shape) == 1: - newsubs = np.expand_dims(newsubs, axis=0) - tt_subscheck(newsubs, nargout=False) - - # Error check on subscripts - if newsubs.shape[1] < self.ndims: - assert False, "Invalid subscripts" - - # Check for expanding the order - if newsubs.shape[1] > self.ndims: - newshape = list(self.shape) - for i in range(self.ndims, newsubs.shape[1]): - newshape.append(1) + if self.subs.size > 0: + # Replace existing values + loc = ttb.tt_intersect_rows(self.subs, addsubs) + self.vals[loc] = value + # pare down list of subscripts to add + addsubs = addsubs[ttb.tt_setdiff_rows(addsubs, self.subs)] + + # If there are things to insert then insert them + if addsubs.size > 0: if self.subs.size > 0: - self.subs = np.concatenate( - ( - self.subs, - np.ones( - (self.shape[0], newsubs.shape[1] - self.ndims), - dtype=int, - ), - ), - axis=1, + self.subs = np.vstack((self.subs, addsubs.astype(int))) + self.vals = np.vstack( + (self.vals, value * np.ones((addsubs.shape[0], 1))) ) - self.shape = tuple(newshape) - - # Copy rhs to newvals - newvals = value - - if isinstance(newvals, (float, int)): - newvals = np.expand_dims([newvals], axis=1) - - # Error check the rhs is a column vector. We don't bother to handle any - # other type with sparse tensors - tt_valscheck(newvals, nargout=False) - - # Determine number of nonzeros being inserted. - # (This is determined by number of subscripts) - newnnz = newsubs.shape[0] - - # Error check on size of newvals - if newvals.size == 1: - # Special case where newvals is a single element to be assigned - # to multiple LHS. Fix to correct size - newvals = newvals * np.ones((newnnz, 1)) - - elif newvals.shape[0] != newnnz: - # Sizes don't match - assert False, "Number of subscripts and number of values do not match!" - - # Remove duplicates and print warning if any duplicates were removed - newsubs, idx = np.unique(newsubs, axis=0, return_index=True) - if newsubs.shape[0] != newnnz: - warnings.warn("Duplicate assignments discarded") - - newvals = newvals[idx] - - # Find which subscripts already exist and their locations - tf = ttb.tt_ismember_rows(newsubs, self.subs) - loc = np.where(tf >= 0)[0].astype(int) - - # Split into three groups for processing: - # - # Group A: Elements that already exist and need to be changed - # Group B: Elements that already exist and need to be removed - # Group C: Elements that do not exist and need to be added - # - # Note that we are ignoring any new zero elements, because - # those obviously do not need to be added. Also, it's - # important to process Group A before Group B because the - # processing of Group B may change the locations of the - # remaining elements. - - # TF+1 for logical consideration because 0 is valid index - # and -1 is our null flag - idxa = np.logical_and(tf + 1, newvals)[0] - idxb = np.logical_and(tf + 1, np.logical_not(newvals))[0] - idxc = np.logical_and(np.logical_not(tf + 1), newvals)[0] - - # Process Group A: Changing values - if np.sum(idxa) > 0: - self.vals[tf[idxa]] = newvals[idxa] - # Proces Group B: Removing Values - if np.sum(idxb) > 0: - removesubs = loc[idxb] - keepsubs = np.setdiff1d(range(0, self.nnz), removesubs) - self.subs = self.subs[keepsubs, :] - self.vals = self.vals[keepsubs] - # Process Group C: Adding new, nonzero values - if np.sum(idxc) > 0: - self.subs = np.vstack((self.subs, newsubs[idxc, :])) - self.vals = np.vstack((self.vals, newvals[idxc])) - - # Resize the tensor - newshape = [] - for n, dim in enumerate(self.shape): - smax = max(newsubs[:, n] + 1) - newshape.append(max(dim, smax)) - self.shape = tuple(newshape) - + else: + self.subs = addsubs.astype(int) + self.vals = value * np.ones((addsubs.shape[0], 1)) return + assert False, "Invalid assignment value" + def __eq__(self, other): """ Equal comparator for sptensors diff --git a/pyttb/tensor.py b/pyttb/tensor.py index 6acae9f6..3bebc6ea 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -1247,9 +1247,7 @@ def ttsv( return y assert False, "Invalid value for version; should be None, 1, or 2" - def __setitem__( - self, key, value - ): # pylint: disable=too-many-branches, too-many-statements + def __setitem__(self, key, value): """ SUBSASGN Subscripted assignment for a tensor. @@ -1298,103 +1296,109 @@ def __setitem__( # Case 1: Rectangular Subtensor if access_type == "subtensor": - # Extract array of subscripts - subs = key - - # Will the size change? If so we first need to resize x - n = self.ndims - sliceCheck = [] - for element in subs: - if isinstance(element, slice): - if element.stop is None: - sliceCheck.append(1) - else: - sliceCheck.append(element.stop) - else: - sliceCheck.append(element) - bsiz = np.array(sliceCheck) - if n == 0: - newsiz = (bsiz[n:] + 1).astype(int) - else: - newsiz = np.concatenate( - (np.max((self.shape, bsiz[0:n] + 1), axis=0), bsiz[n:] + 1) - ).astype(int) - if (newsiz != self.shape).any(): - # We need to enlarge x.data. - newData = np.zeros(shape=tuple(newsiz)) - idx = [slice(None, currentShape) for currentShape in self.shape] - if self.data.size > 0: - newData[tuple(idx)] = self.data - self.data = newData - - self.shape = tuple(newsiz) - if isinstance(value, ttb.tensor): - self.data[key] = value.data - else: - self.data[key] = value - - return + return self._set_subtensor(key, value) # Case 2a: Subscript indexing if access_type == "subscripts": - # Extract array of subscripts - subs = key - - # Will the size change? If so we first need to resize x - n = self.ndims - if ( - len(subs.shape) == 1 - and len(self.shape) == 1 - and self.shape[0] < subs.shape[0] - ): - bsiz = subs - elif len(subs.shape) == 1: - bsiz = np.array([np.max(subs, axis=0)]) - key = key.tolist() - else: - bsiz = np.array(np.max(subs, axis=0)) - if n == 0: - newsiz = (bsiz[n:] + 1).astype(int) - else: - newsiz = np.concatenate( - (np.max((self.shape, bsiz[0:n] + 1), axis=0), bsiz[n:] + 1) - ).astype(int) + return self._set_subscripts(key, value) - if (newsiz != self.shape).any(): - # We need to enlarge x.data. - newData = np.zeros(shape=tuple(newsiz)) - idx = [slice(None, currentShape) for currentShape in self.shape] - if self.data.size > 0: - newData[idx] = self.data - self.data = newData + # Case 2b: Linear Indexing + if access_type == "linear indices": + return self._set_linear(key, value) - self.shape = tuple(newsiz) + assert False, "Invalid use of tensor setitem" - # Finally we can copy in new data - if isinstance(key, list): - self.data[key] = value - elif key.shape[0] == 1: # and len(key.shape) == 1: - self.data[tuple(key[0, :])] = value - else: - self.data[tuple(key)] = value - return + def _set_linear(self, key, value): + idx = key + if (idx > np.prod(self.shape)).any(): + assert ( + False + ), "TTB:BadIndex In assignment X[I] = Y, a tensor X cannot be resized" + idx = tt_ind2sub(self.shape, idx) + if idx.shape[0] == 1: + self.data[tuple(idx[0, :])] = value + else: + actualIdx = tuple(idx.transpose()) + self.data[actualIdx] = value - # Case 2b: Linear Indexing - if access_type == "linear indices": - idx = key - if (idx > np.prod(self.shape)).any(): - assert ( - False - ), "TTB:BadIndex In assignment X[I] = Y, a tensor X cannot be resized" - idx = tt_ind2sub(self.shape, idx) - if idx.shape[0] == 1: - self.data[tuple(idx[0, :])] = value + def _set_subtensor(self, key, value): + # Extract array of subscripts + subs = key + # Will the size change? If so we first need to resize x + n = self.ndims + sliceCheck = [] + for element in subs: + if isinstance(element, slice): + if element.stop is None: + sliceCheck.append(1) + else: + sliceCheck.append(element.stop) else: - actualIdx = tuple(idx.transpose()) - self.data[actualIdx] = value - return + sliceCheck.append(element) + bsiz = np.array(sliceCheck) + if n == 0: + newsiz = (bsiz[n:] + 1).astype(int) + else: + newsiz = np.concatenate( + (np.max((self.shape, bsiz[0:n] + 1), axis=0), bsiz[n:] + 1) + ).astype(int) + if not np.array_equal(newsiz, self.shape): + # We need to enlarge x.data. + newData = np.zeros(shape=tuple(newsiz)) + if self.data.size > 0: + idx = [slice(None, currentShape) for currentShape in self.shape] + idx.extend([0] * (len(newsiz) - self.ndims)) + newData[tuple(idx)] = self.data + self.data = newData - assert False, "Invalid use of tensor setitem" + self.shape = tuple(newsiz) + if isinstance(value, ttb.tensor): + self.data[key] = value.data + else: + self.data[key] = value + + def _set_subscripts(self, key, value): + # Extract array of subscripts + subs = key + + # Will the size change? If so we first need to resize x + n = self.ndims + if ( + len(subs.shape) == 1 + and len(self.shape) == 1 + and self.shape[0] < subs.shape[0] + ): + bsiz = subs + elif len(subs.shape) == 1: + bsiz = np.array([np.max(subs, axis=0)]) + key = key.tolist() + else: + bsiz = np.array(np.max(subs, axis=0)) + if n == 0: + newsiz = (bsiz[n:] + 1).astype(int) + else: + newsiz = np.concatenate( + (np.max((self.shape, bsiz[0:n] + 1), axis=0), bsiz[n:] + 1) + ).astype(int) + + if not np.array_equal(newsiz, self.shape): + # We need to enlarge x.data. + newData = np.zeros(shape=tuple(newsiz)) + if self.data.size > 0: + idx = [slice(None, currentShape) for currentShape in self.shape] + idx.extend([0] * (len(newsiz) - self.ndims)) + newData[tuple(idx)] = self.data + self.data = newData + + self.shape = tuple(newsiz) + + # Finally we can copy in new data + if isinstance(key, list): + self.data[key] = value + elif key.shape[0] == 1: # and len(key.shape) == 1: + self.data[tuple(key[0, :])] = value + else: + self.data[tuple(key)] = value def __getitem__(self, item): """ diff --git a/tests/test_sptensor.py b/tests/test_sptensor.py index 8efb308c..3a14ccdb 100644 --- a/tests/test_sptensor.py +++ b/tests/test_sptensor.py @@ -441,6 +441,12 @@ def test_sptensor_setitem_Case1(sample_sptensor): assert (sptensorInstance.vals == data["vals"][subSelection]).all() assert sptensorInstance.shape == data["shape"] + # Case I(b)i: Set with non-zero, no subs exist + empty_tensor = ttb.sptensor() + empty_tensor[0, 0] = 1 + # Validate entry worked correctly + empty_tensor.__repr__() + # Case I(b)i: Set with zero, sub doesn't exist sptensorInstance[1, 1, 3] = old_value reorder = [0, 2, 3, 1] @@ -555,6 +561,11 @@ def test_sptensor_setitem_Case2(sample_sptensor): ) assert "Duplicate assignments discarded" in str(record[0].message) + # Case II: Single entry, no subs exist + empty_tensor = ttb.sptensor() + empty_tensor[np.array([[0, 1], [2, 2]])] = 4 + assert np.all(empty_tensor[np.array([[0, 1], [2, 2]])] == 4) + # Case II: Single entry, for single sub that exists sptensorInstance[np.array([1, 1, 1]).astype(int)] = 999.0 assert (sptensorInstance[np.array([[1, 1, 1]])] == np.array([[999]])).all() diff --git a/tests/test_tensor.py b/tests/test_tensor.py index e78c2115..8894e0ea 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -244,6 +244,13 @@ def test_tensor__setitem__(sample_tensor_2way): tensorInstance[:, :] = tensorInstance assert (tensorInstance.data == dataGrowth).all() + # Subtensor add element to empty tensor + empty_tensor = ttb.tensor() + empty_tensor[0, 0] = 1 + + # Subtensor add dimension + empty_tensor[0, 0, 0] = 2 + # Subscripts with constant tensorInstance[np.array([[1, 1]])] = 13.0 dataGrowth[1, 1] = 13.0 @@ -259,6 +266,15 @@ def test_tensor__setitem__(sample_tensor_2way): dataGrowth[([1, 1], [1, 2])] = 13.0 assert (tensorInstance.data == dataGrowth).all() + # Subscripts add element to empty tensor + empty_tensor = ttb.tensor() + first_arbitrary_index = np.array([[0, 1], [2, 2]]) + second_arbitrary_index = np.array([[1, 2], [3, 3]]) + value = 4 + empty_tensor[first_arbitrary_index] = value + # Subscripts grow existing tensor + empty_tensor[second_arbitrary_index] = value + # Linear Index with constant tensorInstance[np.array([0])] = 13.0 dataGrowth[np.unravel_index([0], dataGrowth.shape, "F")] = 13.0