
Commit 06f56ae
updated getitem to distributed constants for slices, ADVANCED INDEXING NOT CORRECTED YET
coquelin77 committed May 11, 2021
1 parent 813369d commit 06f56ae
Showing 1 changed file with 9 additions and 1 deletion.
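In short, the change makes a slice that resolves to a single element behave like a constant across all processes. A minimal usage sketch of the intended behavior (a hypothetical example, not part of the commit; assumes Heat is installed and the script is launched under MPI):

```python
# Hypothetical illustration of the behavior this commit targets;
# run with e.g. `mpirun -np 4 python example.py`.
import heat as ht

x = ht.arange(10, split=0)  # elements distributed along dimension 0
c = x[3:4]                  # slice that resolves to a single element
# With this change, the single-element result is broadcast, so every
# rank holds the same value even though element 3 lives on one process.
print(c)
```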
heat/core/dndarray.py (9 additions, 1 deletion)
@@ -782,6 +782,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
         chunk_ends = chunk_starts + counts
         chunk_start = chunk_starts[rank]
         chunk_end = chunk_ends[rank]
+        active_rank = None

         if len(key) == 0:  # handle empty list
             # this will return an array of shape (0, ...)
@@ -798,6 +799,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
             and len(key[self.split]) > 1
         ):
             # advanced indexing, elements in the split dimension are adjusted to the local indices
+            # todo: handle the single element return...for loop? need to find which
+            # process the data is on
             lkey = list(key)
             if isinstance(key[self.split], DNDarray):
                 lkey[self.split] = key[self.split].larray
@@ -849,6 +852,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
             sp_pr = torch.where(key_stop >= chunk_starts)[0]
             sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
             actives = list(range(st_pr, sp_pr + 1))
+            if len(actives) == 1:
+                active_rank = actives[0]
             if rank in actives:
                 key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
                 key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
@@ -895,8 +900,11 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
             arr = self.__array[tuple(key)].reshape(tuple(lout))
         else:
             arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device)
-        # broadcast result

+        if gout_full == (1,) or gout_full == [1]:
+            # broadcast result if the output is a single element (constant)
+            arr = self.comm.bcast(arr, root=active_rank)
+            new_split = None

         # if 0 in arr.shape:
         #     # no data on process
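The advanced-indexing branch above (flagged as not corrected yet in the commit message) adjusts global indices along the split dimension to rank-local ones. A standalone sketch of that adjustment for a single rank, with illustrative variable names that are not Heat's own:

```python
# Sketch of the global-to-local index adjustment along the split
# dimension; names are illustrative, not taken from Heat internals.
import torch

chunk_start, chunk_end = 3, 6     # this rank owns global indices [3, 6)
key = torch.tensor([1, 4, 5, 8])  # global indices requested by the user

mask = (key >= chunk_start) & (key < chunk_end)
local_key = key[mask] - chunk_start  # indices usable on the local tensor
print(local_key)                     # tensor([1, 2]) -> global 4 and 5
```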
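The slice-localization arithmetic from the third hunk can be exercised without MPI. A single-process sketch that finds which ranks hold the global slice [key_start, key_stop) and clips the bounds to each rank's chunk (the `st_pr` computation is inferred from the surrounding code and may differ in detail from Heat's):

```python
# Single-process sketch of the active-rank bookkeeping in the diff above.
import torch

counts = torch.tensor([3, 3, 2, 2])              # local lengths per rank
chunk_starts = torch.cumsum(counts, 0) - counts  # [0, 3, 6, 8]
chunk_ends = chunk_starts + counts               # [3, 6, 8, 10]

key_start, key_stop = 4, 7  # global slice [4, 7)

# first and last ranks whose chunks intersect the slice (st_pr inferred)
st_pr = torch.where(key_start < chunk_ends)[0]
st_pr = int(st_pr[0]) if len(st_pr) > 0 else len(counts)
sp_pr = torch.where(key_stop >= chunk_starts)[0]
sp_pr = int(sp_pr[-1]) if len(sp_pr) > 0 else 0
actives = list(range(st_pr, sp_pr + 1))  # [1, 2]

for rank in actives:
    lo = 0 if rank != actives[0] else key_start - chunk_starts[rank]
    hi = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
    print(f"rank {rank}: local slice [{int(lo)}, {int(hi)})")
```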

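Finally, the new `self.comm.bcast(arr, root=active_rank)` call is what turns the single-element result into a constant on every process. A minimal mpi4py sketch of the same pattern (a standalone illustration; Heat routes this through its own communicator wrapper):

```python
# Minimal mpi4py sketch of broadcasting a constant from the owning rank;
# run with e.g. `mpirun -np 4 python bcast_sketch.py`.
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

active_rank = 1  # the rank that actually owns the single element
value = 42.0 if rank == active_rank else None

value = comm.bcast(value, root=active_rank)  # now identical everywhere
print(f"rank {rank}: {value}")
```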