From 06f56ae9b682a38055e172478af639b10a43e136 Mon Sep 17 00:00:00 2001
From: coquelin77
Date: Tue, 11 May 2021 22:20:59 +0200
Subject: [PATCH] updated getitem to distributed constants for slices, ADVANCED INDEXING NOT CORRECTED YET

---
 heat/core/dndarray.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 61bc45d9b3..6d3674a98b 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -782,6 +782,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         chunk_ends = chunk_starts + counts
         chunk_start = chunk_starts[rank]
         chunk_end = chunk_ends[rank]
+        active_rank = None
         if len(key) == 0:  # handle empty list
             # this will return an array of shape (0, ...)
@@ -798,6 +799,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             and len(key[self.split]) > 1
         ):
             # advanced indexing, elements in the split dimension are adjusted to the local indices
+            # todo: handle the single element return...for loop? need to find which
+            # process the data is on
             lkey = list(key)
             if isinstance(key[self.split], DNDarray):
                 lkey[self.split] = key[self.split].larray
@@ -849,6 +852,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             sp_pr = torch.where(key_stop >= chunk_starts)[0]
             sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
             actives = list(range(st_pr, sp_pr + 1))
+            if len(actives) == 1:
+                active_rank = actives[0]
             if rank in actives:
                 key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
                 key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
@@ -895,8 +900,11 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             arr = self.__array[tuple(key)].reshape(tuple(lout))
         else:
             arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device)
-            # broadcast result
+
+        if gout_full == (1,) or gout_full == [1]:
+            # broadcast result if the output is a single element (constant)
             arr = self.comm.bcast(arr, root=active_rank)
+            new_split = None
 
         # if 0 in arr.shape:
         #     # no data on process
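
For context, the sketch below illustrates the idea this patch applies to slices: when indexing a split DNDarray collapses the result to a single element, exactly one process owns that element, so its rank must be identified (the active_rank recorded above) and the value broadcast so every process returns the same, no-longer-split result. This is not Heat's implementation; it is a minimal stand-alone sketch using mpi4py and NumPy directly, and the names counts, chunk_starts, owner and global_index are illustrative stand-ins for the corresponding variables in the diff.

# Minimal sketch (assumed names, not Heat's API): find the rank that owns a
# single global index of a block-distributed 1-D array and broadcast its value.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, nprocs = comm.Get_rank(), comm.Get_size()

# global array of 16 elements, block-distributed along axis 0
global_size = 16
counts = np.full(nprocs, global_size // nprocs)
counts[: global_size % nprocs] += 1                    # spread the remainder
chunk_starts = np.insert(np.cumsum(counts), 0, 0)[:-1]  # first global index per rank
local = np.arange(chunk_starts[rank], chunk_starts[rank] + counts[rank], dtype=np.float64)

# indexing a single global element, e.g. x[5]
global_index = 5
owner = int(np.searchsorted(chunk_starts, global_index, side="right") - 1)  # the "active rank"

if rank == owner:
    value = local[global_index - chunk_starts[rank]]   # shift to local coordinates
else:
    value = None

# broadcast the constant so every rank returns the same, replicated result
value = comm.bcast(value, root=owner)
print(f"rank {rank}: x[{global_index}] = {value}")

Run with, e.g., mpirun -np 4 python sketch.py. Clearing the split of the returned value, as new_split = None does in the patch, matches this behaviour: after the broadcast the single-element result is replicated on every rank rather than distributed.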