Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement distributed unfold operation #1419

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
4f35da8
implemented the easy cases and a simple test
FOsterfeld Mar 19, 2024
ff74451
general case
FOsterfeld Apr 2, 2024
ff38eb2
exception handling, added test with two unfold (2D slices)
FOsterfeld Apr 2, 2024
b95d40a
added unfold to manipulations module
FOsterfeld Apr 2, 2024
e01daf0
added test
FOsterfeld Apr 2, 2024
31fd8b4
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Apr 2, 2024
b334002
fixed behavior for empty unfold_loc, exception handling for size - 1 …
FOsterfeld Apr 3, 2024
1747481
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Apr 3, 2024
2e04c11
wrong exception type in test
FOsterfeld Apr 3, 2024
4e9bbe2
fixed wrong exception type in tests
FOsterfeld Apr 8, 2024
ad9c797
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Apr 8, 2024
c28b99c
fixed test for single node setting
FOsterfeld Apr 8, 2024
f67ef7e
added better docstring
FOsterfeld Apr 10, 2024
b40a715
added test to cover case that there are no fully local unfolds for a …
FOsterfeld Apr 10, 2024
8b01812
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Apr 10, 2024
713e2ad
fixed test case of no fully local unfolds
FOsterfeld Apr 10, 2024
b323da8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
d833a77
fixed error due to unspecified torch device
FOsterfeld Apr 16, 2024
e7d4cea
Merge branch 'features/1400-Implement_unfold-operation_similar_to_tor…
FOsterfeld Apr 16, 2024
2abed99
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Apr 16, 2024
df5c2a3
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Apr 16, 2024
2923eca
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
ClaudiaComito Apr 22, 2024
fb5a408
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
ClaudiaComito May 21, 2024
f3b2c6d
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
mrfh92 May 31, 2024
6c7ad84
added tests with different datatypes
FOsterfeld Jun 1, 2024
9ad8d69
renamed ´dimension´ to ´axis´
FOsterfeld Jun 4, 2024
778eb33
use `DNDarray.counts_displs()´
FOsterfeld Jun 4, 2024
c36d6a8
updated docstring
FOsterfeld Jun 4, 2024
5cdd986
use sanitize_axis
FOsterfeld Jun 4, 2024
63aa686
support one-sided halo
ClaudiaComito Jun 4, 2024
9b059f8
Merge pull request #1509 from helmholtz-analytics/features/allow-ones…
FOsterfeld Jun 5, 2024
04d1217
use `DNDarray.array_with_halos`
FOsterfeld Jun 5, 2024
b4d8c6c
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Jun 5, 2024
bcd64aa
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
ClaudiaComito Jun 10, 2024
eed36bc
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
mrfh92 Jun 13, 2024
61a5512
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
ClaudiaComito Jun 21, 2024
dca31f0
fixed condition for empty local unfold data
FOsterfeld Jul 5, 2024
2acebe7
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Jul 5, 2024
82d83ea
more tests
FOsterfeld Jul 5, 2024
562d9a0
detach after cloning
FOsterfeld Jul 5, 2024
825979c
test: blocking send in get_halo()
FOsterfeld Jul 5, 2024
e64291e
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
mrfh92 Jul 11, 2024
e67b415
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
ClaudiaComito Jul 15, 2024
3d892d1
replaced Send by Isend in "next"
Jul 15, 2024
10e7444
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
Jul 19, 2024
20c3bc2
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
mrfh92 Jul 19, 2024
6d92fb9
int64 in batchparallel clustering predict
Jul 23, 2024
757116b
Merge branch 'features/1400-Implement_unfold-operation_similar_to_tor…
Jul 23, 2024
76efe78
added error for size=1
FOsterfeld Jul 23, 2024
cd01cbb
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Jul 23, 2024
e6ef047
Update batchparallelclustering.py
mrfh92 Jul 23, 2024
1d1eb6e
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
mrfh92 Aug 12, 2024
abea89d
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
mrfh92 Aug 13, 2024
d00569a
Removed old/dead code, resolved review
FOsterfeld Aug 17, 2024
1af7076
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
FOsterfeld Aug 17, 2024
d2e3ce9
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
ClaudiaComito Aug 19, 2024
c03db4c
Merge branch 'main' into features/1400-Implement_unfold-operation_sim…
ClaudiaComito Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,18 @@ def __prephalo(self, start, end) -> torch.Tensor:

return self.__array[ix].clone().contiguous()

def get_halo(self, halo_size: int) -> torch.Tensor:
def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torch.Tensor:
"""
Fetch halos of size ``halo_size`` from neighboring ranks and save them in ``self.halo_next/self.halo_prev``.

Parameters
----------
halo_size : int
Size of the halo.
prev : bool, optional
If True, fetch the halo from the previous rank. Default: True.
next : bool, optional
If True, fetch the halo from the next rank. Default: True.
"""
if not isinstance(halo_size, int):
raise TypeError(
Expand Down Expand Up @@ -433,25 +437,29 @@ def get_halo(self, halo_size: int) -> torch.Tensor:
req_list = []

# exchange data with next populated process
if rank != last_rank:
self.comm.Isend(a_next, next_rank)
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=next_rank))
if prev:
if rank != last_rank:
self.comm.Isend(a_next, next_rank)
if rank != first_rank:
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=prev_rank))

if rank != first_rank:
self.comm.Isend(a_prev, prev_rank)
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=prev_rank))
if next:
if rank != first_rank:
req_list.append(self.comm.Isend(a_prev, prev_rank))
if rank != last_rank:
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=next_rank))

for req in req_list:
req.Wait()

self.__halo_next = res_prev
self.__halo_prev = res_next
self.__halo_next = res_next
self.__halo_prev = res_prev
self.__ishalo = True

def __cat_halo(self) -> torch.Tensor:
Expand Down
90 changes: 90 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"unique",
"vsplit",
"vstack",
"unfold",
]


Expand Down Expand Up @@ -4213,3 +4214,92 @@ def mpi_topk(a, b, mpi_type):


MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)


def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
"""
Returns a DNDarray which contains all slices of size `size` in the axis `axis`.

Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)

Parameters
----------
a : DNDarray
array to unfold
axis : int
axis in which unfolding happens
size : int
the size of each slice that is unfolded, must be greater than 1
step : int
the step between each slice, must be at least 1

Example:
```
>>> x = ht.arange(1., 8)
>>> x
DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=e)
>>> ht.unfold(x, 0, 2, 1)
DNDarray([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 2)
DNDarray([[1., 2.],
[3., 4.],
[5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
```

FOsterfeld marked this conversation as resolved.
Show resolved Hide resolved
Note
---------
You have to make sure that every node has at least chunk size size-1 if the split axis of the array is the unfold axis.
"""
if step < 1:
raise ValueError("step must be >= 1.")
if size <= 1:
raise ValueError("size must be > 1.")
axis = stride_tricks.sanitize_axis(a.shape, axis)
if size > a.shape[axis]:
raise ValueError(
f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
)

comm = a.comm
dev = a.device
tdev = dev.torch_device

if a.split is None or comm.size == 1 or a.split != axis: # early out
ret = factories.array(
a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
)

return ret
else: # comm.size > 1 and split axis == unfold axis
# index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
# --> size of axis: ceil((sizedim-size+1) / step) = floor(sizedim-size) / step)) + 1
# ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, a_shape[axis+1:], size)

if (size - 1 > a.lshape_map[:, axis]).any():
raise RuntimeError("Chunk-size needs to be at least size - 1.")
a.get_halo(size - 1, prev=False)

counts, displs = a.counts_displs()
displs = torch.tensor(displs, device=tdev)

# min local index in unfold axis
min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
if min_index >= a.lshape[axis] or (
comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
):
loc_unfold_shape = list(a.lshape)
loc_unfold_shape[axis] = 0
ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
else: # unfold has local data
ret_larray = a.array_with_halos[
axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
].unfold(axis, size, step)

ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)

return ret
60 changes: 60 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3752,3 +3752,63 @@ def test_vstack(self):
b = ht.ones((12,), split=0)
res = ht.vstack((a, b))
self.assertEqual(res.shape, (2, 12))

def test_unfold(self):
dtypes = (ht.int, ht.float)

for dtype in dtypes: # test with different datatypes
# exceptions
n = 1000
x = ht.arange(n, dtype=dtype)
with self.assertRaises(ValueError): # size too small
ht.unfold(x, 0, 1, 1)
with self.assertRaises(ValueError): # step too small
ht.unfold(x, 0, 2, 0)
x.resplit_(0)
min_chunk_size = x.lshape_map[:, 0].min().item()
if min_chunk_size + 2 > n: # size too large
with self.assertRaises(ValueError):
ht.unfold(x, 0, min_chunk_size + 2)
else: # size too large for chunk_size
with self.assertRaises(RuntimeError):
ht.unfold(x, 0, min_chunk_size + 2)
with self.assertRaises(ValueError): # size too large
ht.unfold(x, 0, n + 1, 1)
ht.unfold(
x, 0, min_chunk_size, min_chunk_size + 1
) # no fully local unfolds on some nodes

# 2D sliding views
n = 100

x = torch.arange(n * n).reshape((n, n))
y = ht.array(x, dtype)
y.resplit_(0)

u = x.unfold(0, 3, 3)
u = u.unfold(1, 3, 3)
u = ht.array(u)
v = ht.unfold(y, 0, 3, 3)
v = ht.unfold(v, 1, 3, 3)

self.assertTrue(ht.equal(u, v))

# more dimensions, different split axes
n = 53
k = 3 # number of dimensions
shape = k * (n,)
size = n**k

x = torch.arange(size).reshape(shape)
_y = x.clone().detach()
y = ht.array(_y, dtype)

for split in (None, *range(k)):
y.resplit_(split)
for size in range(2, 9):
for step in range(1, 21):
for dimension in range(k):
u = ht.array(x.unfold(dimension, size, step))
v = ht.unfold(y, dimension, size, step)

self.assertTrue(ht.equal(u, v))
Loading