From 4f35da8ae1a5bd98e8dcf555025750de2da0dd56 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 19 Mar 2024 16:39:52 +0100
Subject: [PATCH 01/30] implemented the easy cases and a simple test

---
 unfold_test.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 unfold_test.py

diff --git a/unfold_test.py b/unfold_test.py
new file mode 100644
index 0000000000..9a2aba7448
--- /dev/null
+++ b/unfold_test.py
@@ -0,0 +1,39 @@
+"""
+Test module for DNDarray.unfold
+"""
+
+import heat as ht
+from mpi4py import MPI
+import torch
+
+from heat import factories
+from heat import DNDarray
+
+
+def unfold(a: DNDarray, dimension: int, size: int, step: int):
+    """
+    Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.
+
+    Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
+    """
+    comm = a.comm
+    dev = a.device
+
+    if a.split is None or comm.size == 1 or a.split != dimension:  # early out
+        ret = factories.array(
+            a.larray.unfold(dimension, size, step), is_split=a.split, device=dev, comm=comm
+        )
+
+        return ret
+
+
+# tests
+n = 8
+
+x = torch.arange(0.0, n)
+y = 2 * x.unsqueeze(1) + x.unsqueeze(0)
+y = factories.array(y)
+y.resplit_(1)
+
+print(y)
+print(unfold(y, 0, 3, 1))
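[Editor's note — not part of the patch series] Patch 01 only covers the cases where no data crosses process boundaries. For reference, a minimal torch-only sketch of the sliding-window semantics that unfold mirrors:

    import torch

    t = torch.arange(8.0)       # tensor([0., 1., ..., 7.])
    print(t.unfold(0, 3, 1))    # windows [0,1,2], [1,2,3], ..., [5,6,7]
    # window count: floor((8 - 3) / 1) + 1 = 6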
From ff74451421b14271317abbadaaf7e785e271c7db Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 2 Apr 2024 16:10:05 +0200
Subject: [PATCH 02/30] general case

---
 unfold_test.py | 89 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 81 insertions(+), 8 deletions(-)

diff --git a/unfold_test.py b/unfold_test.py
index 9a2aba7448..9e9fa9d70c 100644
--- a/unfold_test.py
+++ b/unfold_test.py
@@ -10,7 +10,7 @@
 from heat import DNDarray
 
 
-def unfold(a: DNDarray, dimension: int, size: int, step: int):
+def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
     """
     Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.
 
@@ -18,6 +18,7 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int):
     """
     comm = a.comm
     dev = a.device
+    tdev = dev.torch_device
 
     if a.split is None or comm.size == 1 or a.split != dimension:  # early out
         ret = factories.array(
@@ -25,15 +26,87 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int):
         )
 
         return ret
+    else:  # comm.size > 1 and split axis == unfold axis
+        # initialize the array
+        # a_shape = a.shape
+        # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
+        # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
+        # ret_shape = (*a_shape[:dimension], int((a_shape[dimension]-size)/step) + 1, a_shape[dimension+1:], size)
+
+        # ret = ht.zeros(ret_shape, device=dev, split=a.split)
+
+        # send the needed entries in the unfold dimension from node n to n+1 or n-1
+        a.get_halo(size - 1)
+        a_lshapes_cum = torch.hstack(
+            [
+                torch.zeros(1, dtype=torch.int32, device=tdev),
+                torch.cumsum(a.lshape_map[:, dimension], 0),
+            ]
+        )
+        if comm.rank == 0:
+            print(a_lshapes_cum)
+        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[
+            comm.rank
+        ]  # min local index in unfold dimension
+        print(f"min_index on rank {comm.rank}: {min_index}")
+        unfold_loc = a.larray[
+            dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
+        ].unfold(dimension, size, step)
+        ret_larray = unfold_loc
+        if comm.rank < comm.size - 1:  # possibly unfold with halo from next rank
+            max_index = a.lshape[dimension] - min_index - 1
+            max_index = max_index // step * step + min_index  # max local index in unfold dimension
+            rem = max_index + size - a.lshape[dimension]
+            if rem > 0:  # need data from halo
+                unfold_halo = torch.cat(
+                    (
+                        a.larray[
+                            dimension * (slice(None, None, None),)
+                            + (slice(max_index, None, None), Ellipsis)
+                        ],
+                        a.halo_next[
+                            dimension * (slice(None, None, None),)
+                            + (slice(None, rem, None), Ellipsis)
+                        ],
+                    ),
+                    dimension,
+                ).unfold(dimension, size, step)
+            ret_larray = torch.cat((unfold_loc, unfold_halo), dimension)
+        ret = factories.array(ret_larray, is_split=dimension, device=dev, comm=comm)
+
+        return ret
 
 
 # tests
-n = 8
+n = 100
+
+# x = torch.arange(0.0, n)
+# y = 2 * x.unsqueeze(1) + x.unsqueeze(0)
+# y = factories.array(y)
+# y.resplit_(1)
+
+# print(y)
+# print(unfold(y, 0, 3, 1))
+
+x = torch.arange(0, n)
+y = factories.array(x)
+y.resplit_(0)
+
+u = x.unfold(0, 5, 10)
+u = factories.array(u)
+v = unfold(y, 0, 5, 10)
+
+comm = u.comm
+# print(v)
+equal = ht.equal(u, v)
+
+u_shape = u.shape
+v_shape = v.shape
 
-x = torch.arange(0.0, n)
-y = 2 * x.unsqueeze(1) + x.unsqueeze(0)
-y = factories.array(y)
-y.resplit_(1)
+if comm.rank == 0:
+    print(f"u.shape: {u_shape}")
+    print(f"v.shape: {v_shape}")
+    print(f"torch and heat unfold equal: {equal}")
+    print(f"u: {u}")
 
-print(y)
-print(unfold(y, 0, 3, 1))
+print(f"v: {v}")
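[Editor's note — not part of the patch series] The min_index expression above aligns each rank's first window with the global step grid; a worked example with assumed values: for step = 3 and a rank whose local slice starts at global index 7, the global window starts are 0, 3, 6, 9, ...; the first start inside this rank is 9, i.e. local index 2, and indeed ((7 - 1) // 3 + 1) * 3 - 7 = (2 + 1) * 3 - 7 = 2.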
From ff38eb21b069beecb4c4092be82670f64f08db79 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 2 Apr 2024 16:33:25 +0200
Subject: [PATCH 03/30] exception handling, added test with two unfolds (2D
 slices)

---
 unfold_test.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/unfold_test.py b/unfold_test.py
index 9e9fa9d70c..40db94763b 100644
--- a/unfold_test.py
+++ b/unfold_test.py
@@ -16,6 +16,13 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
 
     Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
     """
+    if step < 1:
+        raise ValueError("step must be >= 1.")
+    if size < 1:
+        raise ValueError("size must be >= 1.")
+    if dimension < 0 or dimension >= a.ndim:
+        raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
+
     comm = a.comm
     dev = a.device
     tdev = dev.torch_device
@@ -43,12 +50,8 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
                 torch.cumsum(a.lshape_map[:, dimension], 0),
             ]
         )
-        if comm.rank == 0:
-            print(a_lshapes_cum)
-        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[
-            comm.rank
-        ]  # min local index in unfold dimension
-        print(f"min_index on rank {comm.rank}: {min_index}")
+        # min local index in unfold dimension
+        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
         unfold_loc = a.larray[
             dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
         ].unfold(dimension, size, step)
@@ -78,7 +81,7 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
 
 
 # tests
-n = 100
+n = 20
 
 # x = torch.arange(0.0, n)
 # y = 2 * x.unsqueeze(1) + x.unsqueeze(0)
@@ -88,13 +91,17 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
 # print(y)
 # print(unfold(y, 0, 3, 1))
 
-x = torch.arange(0, n)
+x = torch.arange(0, n * n).reshape((n, n))
+# print(f"x: {x}")
 y = factories.array(x)
 y.resplit_(0)
 
-u = x.unfold(0, 5, 10)
+u = x.unfold(0, 3, 3)
+u = u.unfold(1, 3, 3)
 u = factories.array(u)
-v = unfold(y, 0, 5, 10)
+v = unfold(y, 0, 3, 3)
+v.resplit_(1)
+v = unfold(v, 1, 3, 3)
 
 comm = u.comm
 # print(v)
@@ -107,6 +114,6 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
     print(f"u.shape: {u_shape}")
     print(f"v.shape: {v_shape}")
     print(f"torch and heat unfold equal: {equal}")
-    print(f"u: {u}")
+    # print(f"u: {u}")
 
-print(f"v: {v}")
+# print(f"v: {v}")
From b95d40a54e7a6025a0abf117971d3b8e6daac061 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 2 Apr 2024 16:45:27 +0200
Subject: [PATCH 04/30] added unfold to manipulations module

---
 heat/core/manipulations.py | 71 ++++++++++++++++++++++++++++++++++++
 unfold_test.py             | 74 ++------------------------------------
 2 files changed, 73 insertions(+), 72 deletions(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 8e95c39dcb..8dcb5a8d0c 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -61,6 +61,7 @@
     "unique",
     "vsplit",
     "vstack",
+    "unfold",
 ]
 
 
@@ -4178,3 +4179,73 @@ def mpi_topk(a, b, mpi_type):
 
 
 MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)
+
+
+def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
+    """
+    Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.
+
+    Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
+    """
+    if step < 1:
+        raise ValueError("step must be >= 1.")
+    if size < 1:
+        raise ValueError("size must be >= 1.")
+    if dimension < 0 or dimension >= a.ndim:
+        raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
+
+    comm = a.comm
+    dev = a.device
+    tdev = dev.torch_device
+
+    if a.split is None or comm.size == 1 or a.split != dimension:  # early out
+        ret = factories.array(
+            a.larray.unfold(dimension, size, step), is_split=a.split, device=dev, comm=comm
+        )
+
+        return ret
+    else:  # comm.size > 1 and split axis == unfold axis
+        # initialize the array
+        # a_shape = a.shape
+        # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
+        # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
+        # ret_shape = (*a_shape[:dimension], int((a_shape[dimension]-size)/step) + 1, a_shape[dimension+1:], size)
+
+        # ret = ht.zeros(ret_shape, device=dev, split=a.split)
+
+        # send the needed entries in the unfold dimension from node n to n+1 or n-1
+        a.get_halo(size - 1)
+        a_lshapes_cum = torch.hstack(
+            [
+                torch.zeros(1, dtype=torch.int32, device=tdev),
+                torch.cumsum(a.lshape_map[:, dimension], 0),
+            ]
+        )
+        # min local index in unfold dimension
+        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
+        unfold_loc = a.larray[
+            dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
+        ].unfold(dimension, size, step)
+        ret_larray = unfold_loc
+        if comm.rank < comm.size - 1:  # possibly unfold with halo from next rank
+            max_index = a.lshape[dimension] - min_index - 1
+            max_index = max_index // step * step + min_index  # max local index in unfold dimension
+            rem = max_index + size - a.lshape[dimension]
+            if rem > 0:  # need data from halo
+                unfold_halo = torch.cat(
+                    (
+                        a.larray[
+                            dimension * (slice(None, None, None),)
+                            + (slice(max_index, None, None), Ellipsis)
+                        ],
+                        a.halo_next[
+                            dimension * (slice(None, None, None),)
+                            + (slice(None, rem, None), Ellipsis)
+                        ],
+                    ),
+                    dimension,
+                ).unfold(dimension, size, step)
+            ret_larray = torch.cat((unfold_loc, unfold_halo), dimension)
+        ret = factories.array(ret_larray, is_split=dimension, device=dev, comm=comm)
+
+        return ret
diff --git a/unfold_test.py b/unfold_test.py
index 40db94763b..47d1d9ce4e 100644
--- a/unfold_test.py
+++ b/unfold_test.py
@@ -10,76 +10,6 @@
 from heat import DNDarray
 
 
-def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
-    """
-    Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.
-
-    Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
-    """
-    if step < 1:
-        raise ValueError("step must be >= 1.")
-    if size < 1:
-        raise ValueError("size must be >= 1.")
-    if dimension < 0 or dimension >= a.ndim:
-        raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
-
-    comm = a.comm
-    dev = a.device
-    tdev = dev.torch_device
-
-    if a.split is None or comm.size == 1 or a.split != dimension:  # early out
-        ret = factories.array(
-            a.larray.unfold(dimension, size, step), is_split=a.split, device=dev, comm=comm
-        )
-
-        return ret
-    else:  # comm.size > 1 and split axis == unfold axis
-        # initialize the array
-        # a_shape = a.shape
-        # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
-        # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
-        # ret_shape = (*a_shape[:dimension], int((a_shape[dimension]-size)/step) + 1, a_shape[dimension+1:], size)
-
-        # ret = ht.zeros(ret_shape, device=dev, split=a.split)
-
-        # send the needed entries in the unfold dimension from node n to n+1 or n-1
-        a.get_halo(size - 1)
-        a_lshapes_cum = torch.hstack(
-            [
-                torch.zeros(1, dtype=torch.int32, device=tdev),
-                torch.cumsum(a.lshape_map[:, dimension], 0),
-            ]
-        )
-        # min local index in unfold dimension
-        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
-        unfold_loc = a.larray[
-            dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
-        ].unfold(dimension, size, step)
-        ret_larray = unfold_loc
-        if comm.rank < comm.size - 1:  # possibly unfold with halo from next rank
-            max_index = a.lshape[dimension] - min_index - 1
-            max_index = max_index // step * step + min_index  # max local index in unfold dimension
-            rem = max_index + size - a.lshape[dimension]
-            if rem > 0:  # need data from halo
-                unfold_halo = torch.cat(
-                    (
-                        a.larray[
-                            dimension * (slice(None, None, None),)
-                            + (slice(max_index, None, None), Ellipsis)
-                        ],
-                        a.halo_next[
-                            dimension * (slice(None, None, None),)
-                            + (slice(None, rem, None), Ellipsis)
-                        ],
-                    ),
-                    dimension,
-                ).unfold(dimension, size, step)
-            ret_larray = torch.cat((unfold_loc, unfold_halo), dimension)
-        ret = factories.array(ret_larray, is_split=dimension, device=dev, comm=comm)
-
-        return ret
-
-
 # tests
 n = 20
 
@@ -99,9 +29,9 @@
 u = x.unfold(0, 3, 3)
 u = u.unfold(1, 3, 3)
 u = factories.array(u)
-v = unfold(y, 0, 3, 3)
+v = ht.unfold(y, 0, 3, 3)
 v.resplit_(1)
-v = unfold(v, 1, 3, 3)
+v = ht.unfold(v, 1, 3, 3)
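[Editor's note — not part of the patch series] A small sketch of the index-tuple idiom used throughout the implementation: `dimension * (slice(None),) + (slice(lo, None), Ellipsis)` places the open slice at position `dimension`. All names here are illustrative:

    import torch

    t = torch.arange(24).reshape(2, 3, 4)
    dim, lo = 1, 1
    idx = dim * (slice(None, None, None),) + (slice(lo, None, None), Ellipsis)
    assert torch.equal(t[idx], t[:, lo:, :])  # same selection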
From e01daf007c2c0fa15456d90b5697907173d23dc1 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 2 Apr 2024 17:01:11 +0200
Subject: [PATCH 05/30] added test

---
 heat/core/tests/test_manipulations.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 9825d333e9..a0f8518720 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3751,3 +3751,19 @@ def test_vstack(self):
         b = ht.ones((12,), split=0)
         res = ht.vstack((a, b))
         self.assertEqual(res.shape, (2, 12))
+
+    def test_unfold(self):
+        # 2D sliding views
+        n = 20
+
+        x = torch.arange(0, n * n).reshape((n, n))
+        y = ht.array(x)
+        y.resplit_(0)
+
+        u = x.unfold(0, 3, 3)
+        u = u.unfold(1, 3, 3)
+        u = ht.array(u)
+        v = ht.unfold(y, 0, 3, 3)
+        v = ht.unfold(v, 1, 3, 3)
+
+        self.assertTrue(ht.equal(u, v))
From b33400210fbd26f578cd23ff5b85785859f1e46e Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Wed, 3 Apr 2024 17:15:28 +0200
Subject: [PATCH 06/30] fixed behavior for empty unfold_loc, exception handling
 for size - 1 > chunk_size, more tests

---
 heat/core/manipulations.py            | 22 ++++++--
 heat/core/tests/test_manipulations.py | 36 ++++++++++++-
 unfold_test.py                        | 74 ++++++++++++++++++++++++++-
 3 files changed, 124 insertions(+), 8 deletions(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 8dcb5a8d0c..aec425b9e2 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4193,6 +4193,10 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
         raise ValueError("size must be >= 1.")
     if dimension < 0 or dimension >= a.ndim:
         raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
+    if size > a.shape[dimension]:  # size too large
+        raise RuntimeError(
+            f"maximum size for DNDarray at dimension {dimension} is {a.shape[dimension]} but size is {size}."
+        )
 
     comm = a.comm
     dev = a.device
@@ -4214,6 +4218,8 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
         # ret = ht.zeros(ret_shape, device=dev, split=a.split)
 
         # send the needed entries in the unfold dimension from node n to n+1 or n-1
+        if (size - 1 > a.lshape_map[:, dimension]).any():
+            raise ValueError("Chunk-size needs to be at least size - 1.")
         a.get_halo(size - 1)
         a_lshapes_cum = torch.hstack(
             [
@@ -4223,11 +4229,19 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
         )
         # min local index in unfold dimension
         min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
-        unfold_loc = a.larray[
-            dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
-        ].unfold(dimension, size, step)
+        loc_unfold_shape = list(a.lshape)
+        loc_unfold_shape[dimension] -= min_index
+        if loc_unfold_shape[dimension] >= size:  # some unfold arrays are unfolds of the local array
+            unfold_loc = a.larray[
+                dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
+            ].unfold(dimension, size, step)
+        else:
+            loc_unfold_shape[dimension] = 0
+            unfold_loc = torch.zeros((*loc_unfold_shape, size))
         ret_larray = unfold_loc
-        if comm.rank < comm.size - 1:  # possibly unfold with halo from next rank
+        if (
+            comm.rank < comm.size - 1 and min_index < a.lshape[dimension]
+        ):  # possibly unfold with halo from next rank
             max_index = a.lshape[dimension] - min_index - 1
             max_index = max_index // step * step + min_index  # max local index in unfold dimension
             rem = max_index + size - a.lshape[dimension]
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index a0f8518720..0cdf42ef40 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3753,10 +3753,25 @@ def test_vstack(self):
         self.assertEqual(res.shape, (2, 12))
 
     def test_unfold(self):
+        # exceptions
+        x = ht.arange(100)
+        with self.assertRaises(ValueError):
+            ht.unfold(x, -1, 1, 1)
+        with self.assertRaises(ValueError):
+            ht.unfold(x, 0, 0, 1)
+        with self.assertRaises(ValueError):
+            ht.unfold(x, 0, 1, 0)
+        with self.assertRaises(ValueError):  # size too large for chunk_size
+            x.resplit_(0)
+            min_chunk_size = x.lshape_map[:, 0].min()
+            ht.unfold(x, 0, min_chunk_size + 2)
+        with self.assertRaises(RuntimeError):  # size too large
+            ht.unfold(x, 0, 101, 1)
+
         # 2D sliding views
-        n = 20
+        n = 100
 
-        x = torch.arange(0, n * n).reshape((n, n))
+        x = torch.arange(n * n).reshape((n, n))
         y = ht.array(x)
         y.resplit_(0)
 
@@ -3767,3 +3782,20 @@ def test_unfold(self):
         v = ht.unfold(v, 1, 3, 3)
 
         self.assertTrue(ht.equal(u, v))
+
+        # more dimensions, different split axes
+        n = 10
+        k = 5  # number of dimensions
+        shape = k * (n,)
+        size = n**k
+
+        x = torch.arange(size).reshape(shape)
+        y = ht.array(x)
+
+        for split in (None, *range(k)):
+            y.resplit_(split)
+            for dimension in range(k):
+                u = ht.array(x.unfold(dimension, 1, 1))
+                v = ht.unfold(y, dimension, 1, 1)
+
+                self.assertTrue(ht.equal(u, v))
diff --git a/unfold_test.py b/unfold_test.py
index 47d1d9ce4e..40db94763b 100644
--- a/unfold_test.py
+++ b/unfold_test.py
@@ -10,6 +10,76 @@
 from heat import DNDarray
 
 
+def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
+    """
+    Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.
+
+    Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
+    """
+    if step < 1:
+        raise ValueError("step must be >= 1.")
+    if size < 1:
+        raise ValueError("size must be >= 1.")
+    if dimension < 0 or dimension >= a.ndim:
+        raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
+
+    comm = a.comm
+    dev = a.device
+    tdev = dev.torch_device
+
+    if a.split is None or comm.size == 1 or a.split != dimension:  # early out
+        ret = factories.array(
+            a.larray.unfold(dimension, size, step), is_split=a.split, device=dev, comm=comm
+        )
+
+        return ret
+    else:  # comm.size > 1 and split axis == unfold axis
+        # initialize the array
+        # a_shape = a.shape
+        # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
+        # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
+        # ret_shape = (*a_shape[:dimension], int((a_shape[dimension]-size)/step) + 1, a_shape[dimension+1:], size)
+
+        # ret = ht.zeros(ret_shape, device=dev, split=a.split)
+
+        # send the needed entries in the unfold dimension from node n to n+1 or n-1
+        a.get_halo(size - 1)
+        a_lshapes_cum = torch.hstack(
+            [
+                torch.zeros(1, dtype=torch.int32, device=tdev),
+                torch.cumsum(a.lshape_map[:, dimension], 0),
+            ]
+        )
+        # min local index in unfold dimension
+        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
+        unfold_loc = a.larray[
+            dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
+        ].unfold(dimension, size, step)
+        ret_larray = unfold_loc
+        if comm.rank < comm.size - 1:  # possibly unfold with halo from next rank
+            max_index = a.lshape[dimension] - min_index - 1
+            max_index = max_index // step * step + min_index  # max local index in unfold dimension
+            rem = max_index + size - a.lshape[dimension]
+            if rem > 0:  # need data from halo
+                unfold_halo = torch.cat(
+                    (
+                        a.larray[
+                            dimension * (slice(None, None, None),)
+                            + (slice(max_index, None, None), Ellipsis)
+                        ],
+                        a.halo_next[
+                            dimension * (slice(None, None, None),)
+                            + (slice(None, rem, None), Ellipsis)
+                        ],
+                    ),
+                    dimension,
+                ).unfold(dimension, size, step)
+            ret_larray = torch.cat((unfold_loc, unfold_halo), dimension)
+        ret = factories.array(ret_larray, is_split=dimension, device=dev, comm=comm)
+
+        return ret
+
+
 # tests
 n = 20
 
@@ -29,9 +99,9 @@
 u = x.unfold(0, 3, 3)
 u = u.unfold(1, 3, 3)
 u = factories.array(u)
-v = ht.unfold(y, 0, 3, 3)
+v = unfold(y, 0, 3, 3)
 v.resplit_(1)
-v = ht.unfold(v, 1, 3, 3)
+v = unfold(v, 1, 3, 3)
From 2e04c112d6079d307b1bbabc36ff7fffdb78c7b3 Mon Sep 17 00:00:00 2001
From: FOsterfeld <146953335+FOsterfeld@users.noreply.github.com>
Date: Wed, 3 Apr 2024 17:44:23 +0200
Subject: [PATCH 07/30] wrong exception type in test

---
 heat/core/tests/test_manipulations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 0cdf42ef40..8334613982 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3761,7 +3761,7 @@ def test_unfold(self):
             ht.unfold(x, 0, 0, 1)
         with self.assertRaises(ValueError):
             ht.unfold(x, 0, 1, 0)
-        with self.assertRaises(ValueError):  # size too large for chunk_size
+        with self.assertRaises(RuntimeError):  # size too large for chunk_size
             x.resplit_(0)
             min_chunk_size = x.lshape_map[:, 0].min()
             ht.unfold(x, 0, min_chunk_size + 2)

From 4e9bbe264aa068f1d7d35c1d9d5d016823468739 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Mon, 8 Apr 2024 18:03:45 +0200
Subject: [PATCH 08/30] fixed wrong exception type in tests

---
 heat/core/tests/test_manipulations.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 8334613982..0cd5e7dd98 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3754,19 +3754,20 @@ def test_vstack(self):
 
     def test_unfold(self):
         # exceptions
-        x = ht.arange(100)
+        n = 1000
+        x = ht.arange(n)
         with self.assertRaises(ValueError):
             ht.unfold(x, -1, 1, 1)
         with self.assertRaises(ValueError):
             ht.unfold(x, 0, 0, 1)
         with self.assertRaises(ValueError):
             ht.unfold(x, 0, 1, 0)
-        with self.assertRaises(RuntimeError):  # size too large for chunk_size
+        with self.assertRaises(ValueError):  # size too large for chunk_size
             x.resplit_(0)
             min_chunk_size = x.lshape_map[:, 0].min()
             ht.unfold(x, 0, min_chunk_size + 2)
         with self.assertRaises(RuntimeError):  # size too large
-            ht.unfold(x, 0, 101, 1)
+            ht.unfold(x, 0, n + 1, 1)
From c28b99c360e3a8dd9225d0599bce95eb2de7748d Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Mon, 8 Apr 2024 18:26:07 +0200
Subject: [PATCH 09/30] fixed test for single node setting

---
 heat/core/manipulations.py            |  11 +--
 heat/core/tests/test_manipulations.py |  14 +--
 unfold_test.py                        | 119 --------------------------
 3 files changed, 12 insertions(+), 132 deletions(-)
 delete mode 100644 unfold_test.py

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index aec425b9e2..407aba0652 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4193,8 +4193,8 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
         raise ValueError("size must be >= 1.")
     if dimension < 0 or dimension >= a.ndim:
         raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
-    if size > a.shape[dimension]:  # size too large
-        raise RuntimeError(
+    if size > a.shape[dimension]:  # size too large, runtime error or value error?
+        raise ValueError(
             f"maximum size for DNDarray at dimension {dimension} is {a.shape[dimension]} but size is {size}."
         )
 
@@ -4209,17 +4209,12 @@ def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
 
         return ret
     else:  # comm.size > 1 and split axis == unfold axis
-        # initialize the array
-        # a_shape = a.shape
         # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
         # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
         # ret_shape = (*a_shape[:dimension], int((a_shape[dimension]-size)/step) + 1, a_shape[dimension+1:], size)
 
-        # ret = ht.zeros(ret_shape, device=dev, split=a.split)
-
-        # send the needed entries in the unfold dimension from node n to n+1 or n-1
         if (size - 1 > a.lshape_map[:, dimension]).any():
-            raise ValueError("Chunk-size needs to be at least size - 1.")
+            raise RuntimeError("Chunk-size needs to be at least size - 1.")
         a.get_halo(size - 1)
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 0cd5e7dd98..09765ac0da 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3762,11 +3762,15 @@ def test_unfold(self):
             ht.unfold(x, 0, 0, 1)
         with self.assertRaises(ValueError):
             ht.unfold(x, 0, 1, 0)
-        with self.assertRaises(ValueError):  # size too large for chunk_size
-            x.resplit_(0)
-            min_chunk_size = x.lshape_map[:, 0].min()
-            ht.unfold(x, 0, min_chunk_size + 2)
-        with self.assertRaises(RuntimeError):  # size too large
+        x.resplit_(0)
+        min_chunk_size = x.lshape_map[:, 0].min()
+        if min_chunk_size + 2 > n:  # size too large
+            with self.assertRaises(ValueError):
+                ht.unfold(x, 0, min_chunk_size + 2)
+        else:  # size too large for chunk_size
+            with self.assertRaises(RuntimeError):
+                ht.unfold(x, 0, min_chunk_size + 2)
+        with self.assertRaises(ValueError):  # size too large
             ht.unfold(x, 0, n + 1, 1)
diff --git a/unfold_test.py b/unfold_test.py
deleted file mode 100644
index 40db94763b..0000000000
--- a/unfold_test.py
+++ /dev/null
@@ -1,119 +0,0 @@
-"""
-Test module for DNDarray.unfold
-"""
-
-import heat as ht
-from mpi4py import MPI
-import torch
-
-from heat import factories
-from heat import DNDarray
-
-
-def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
-    """
-    Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.
-
-    Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
-    """
-    if step < 1:
-        raise ValueError("step must be >= 1.")
-    if size < 1:
-        raise ValueError("size must be >= 1.")
-    if dimension < 0 or dimension >= a.ndim:
-        raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
-
-    comm = a.comm
-    dev = a.device
-    tdev = dev.torch_device
-
-    if a.split is None or comm.size == 1 or a.split != dimension:  # early out
-        ret = factories.array(
-            a.larray.unfold(dimension, size, step), is_split=a.split, device=dev, comm=comm
-        )
-
-        return ret
-    else:  # comm.size > 1 and split axis == unfold axis
-        # initialize the array
-        # a_shape = a.shape
-        # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
-        # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
-        # ret_shape = (*a_shape[:dimension], int((a_shape[dimension]-size)/step) + 1, a_shape[dimension+1:], size)
-
-        # ret = ht.zeros(ret_shape, device=dev, split=a.split)
-
-        # send the needed entries in the unfold dimension from node n to n+1 or n-1
-        a.get_halo(size - 1)
-        a_lshapes_cum = torch.hstack(
-            [
-                torch.zeros(1, dtype=torch.int32, device=tdev),
-                torch.cumsum(a.lshape_map[:, dimension], 0),
-            ]
-        )
-        # min local index in unfold dimension
-        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
-        unfold_loc = a.larray[
-            dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
-        ].unfold(dimension, size, step)
-        ret_larray = unfold_loc
-        if comm.rank < comm.size - 1:  # possibly unfold with halo from next rank
-            max_index = a.lshape[dimension] - min_index - 1
-            max_index = max_index // step * step + min_index  # max local index in unfold dimension
-            rem = max_index + size - a.lshape[dimension]
-            if rem > 0:  # need data from halo
-                unfold_halo = torch.cat(
-                    (
-                        a.larray[
-                            dimension * (slice(None, None, None),)
-                            + (slice(max_index, None, None), Ellipsis)
-                        ],
-                        a.halo_next[
-                            dimension * (slice(None, None, None),)
-                            + (slice(None, rem, None), Ellipsis)
-                        ],
-                    ),
-                    dimension,
-                ).unfold(dimension, size, step)
-            ret_larray = torch.cat((unfold_loc, unfold_halo), dimension)
-        ret = factories.array(ret_larray, is_split=dimension, device=dev, comm=comm)
-
-        return ret
-
-
-# tests
-n = 20
-
-# x = torch.arange(0.0, n)
-# y = 2 * x.unsqueeze(1) + x.unsqueeze(0)
-# y = factories.array(y)
-# y.resplit_(1)
-
-# print(y)
-# print(unfold(y, 0, 3, 1))
-
-x = torch.arange(0, n * n).reshape((n, n))
-# print(f"x: {x}")
-y = factories.array(x)
-y.resplit_(0)
-
-u = x.unfold(0, 3, 3)
-u = u.unfold(1, 3, 3)
-u = factories.array(u)
-v = unfold(y, 0, 3, 3)
-v.resplit_(1)
-v = unfold(v, 1, 3, 3)
-
-comm = u.comm
-# print(v)
-equal = ht.equal(u, v)
-
-u_shape = u.shape
-v_shape = v.shape
-
-if comm.rank == 0:
-    print(f"u.shape: {u_shape}")
-    print(f"v.shape: {v_shape}")
-    print(f"torch and heat unfold equal: {equal}")
-    # print(f"u: {u}")
-
-# print(f"v: {v}")
From f67ef7e1c5e8ab89f9c53ccfe03f28fbbc430d64 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Wed, 10 Apr 2024 12:02:34 +0200
Subject: [PATCH 10/30] added better docstring

---
 heat/core/manipulations.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 407aba0652..e994a8d4d6 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4183,9 +4183,24 @@ def mpi_topk(a, b, mpi_type):
 
 def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
     """
-    Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.
+    Returns a DNDarray which contains all slices of size size in the dimension dimension.
 
     Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
+
+    Parameters
+    ----------
+    a : DNDarray
+        array to unfold
+    dimension : int
+        dimension in which unfolding happens
+    size : int
+        the size of each slice that is unfolded
+    step : int
+        the step between each slice
+
+    Note
+    ---------
+    You have to make sure that every node has a chunk of size at least size - 1 if the split dimension of the array is the unfold dimension.
     """
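[Editor's note — not part of the patch series] A numeric reading of the chunk-size requirement in the new docstring (example values assumed): the halo mechanism lets each rank borrow at most size - 1 elements from its successor, so a window may span at most two neighboring ranks. With 1000 elements split evenly over 4 ranks (local chunks of 250), any size up to 251 is safe; a larger window would need data from more than one neighbor and is rejected.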
From b40a71553279f7b12cd7eeb2860c3d74b52be337 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Wed, 10 Apr 2024 12:40:42 +0200
Subject: [PATCH 11/30] added test to cover case that there are no fully local
 unfolds for a node

---
 heat/core/tests/test_manipulations.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 09765ac0da..afd667ab1b 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3763,7 +3763,7 @@ def test_unfold(self):
         with self.assertRaises(ValueError):
             ht.unfold(x, 0, 1, 0)
         x.resplit_(0)
-        min_chunk_size = x.lshape_map[:, 0].min()
+        min_chunk_size = x.lshape_map[:, 0].min().item()
         if min_chunk_size + 2 > n:  # size too large
             with self.assertRaises(ValueError):
                 ht.unfold(x, 0, min_chunk_size + 2)
@@ -3772,6 +3772,7 @@ def test_unfold(self):
         with self.assertRaises(ValueError):  # size too large
             ht.unfold(x, 0, n + 1, 1)
+        ht.unfold(x, 0, min_chunk_size, min_chunk_size // 2)

From 713e2ad6bb4dca605812cdc6f542ac6598a9b31a Mon Sep 17 00:00:00 2001
From: FOsterfeld <146953335+FOsterfeld@users.noreply.github.com>
Date: Wed, 10 Apr 2024 13:04:08 +0200
Subject: [PATCH 12/30] fixed test case of no fully local unfolds

---
 heat/core/tests/test_manipulations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index afd667ab1b..204f90493a 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3772,7 +3772,7 @@ def test_unfold(self):
         with self.assertRaises(ValueError):  # size too large
             ht.unfold(x, 0, n + 1, 1)
-        ht.unfold(x, 0, min_chunk_size, min_chunk_size // 2)
+        ht.unfold(x, 0, min_chunk_size, min_chunk_size + 1) # no fully local unfolds on some nodes

From b323da8b058ccc23f01d659e86030eb52f07f2ab Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 10 Apr 2024 11:04:34 +0000
Subject: [PATCH 13/30] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 heat/core/tests/test_manipulations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 204f90493a..6cd05ee984 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3772,7 +3772,7 @@ def test_unfold(self):
         with self.assertRaises(ValueError):  # size too large
             ht.unfold(x, 0, n + 1, 1)
-        ht.unfold(x, 0, min_chunk_size, min_chunk_size + 1) # no fully local unfolds on some nodes
+        ht.unfold(x, 0, min_chunk_size, min_chunk_size + 1)  # no fully local unfolds on some nodes
""" - Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension. + Returns a DNDarray which contains all slices of size size in the dimension dimension. Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html) + + Parameters + ---------- + a : DNDarray + array to unfold + dimension : int + dimension in which unfolding happens + size : int + the size of each slice that is unfolded + step : int + the step between each slice + + Note + --------- + You have to make sure that every node has at least chunk size size-1 if the split dimension of the array is the unfold dimension. """ if step < 1: raise ValueError("step must be >= 1.") From b40a71553279f7b12cd7eeb2860c3d74b52be337 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Wed, 10 Apr 2024 12:40:42 +0200 Subject: [PATCH 11/30] added test to cover case that there are no fully local unfolds for a node --- heat/core/tests/test_manipulations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 09765ac0da..afd667ab1b 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3763,7 +3763,7 @@ def test_unfold(self): with self.assertRaises(ValueError): ht.unfold(x, 0, 1, 0) x.resplit_(0) - min_chunk_size = x.lshape_map[:, 0].min() + min_chunk_size = x.lshape_map[:, 0].min().item() if min_chunk_size + 2 > n: # size too large with self.assertRaises(ValueError): ht.unfold(x, 0, min_chunk_size + 2) @@ -3772,6 +3772,7 @@ def test_unfold(self): ht.unfold(x, 0, min_chunk_size + 2) with self.assertRaises(ValueError): # size too large ht.unfold(x, 0, n + 1, 1) + ht.unfold(x, 0, min_chunk_size, min_chunk_size // 2) # 2D sliding views n = 100 From 713e2ad6bb4dca605812cdc6f542ac6598a9b31a Mon Sep 17 00:00:00 2001 From: FOsterfeld <146953335+FOsterfeld@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:04:08 +0200 Subject: [PATCH 12/30] fixed test case of no fully local unfolds --- heat/core/tests/test_manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index afd667ab1b..204f90493a 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3772,7 +3772,7 @@ def test_unfold(self): ht.unfold(x, 0, min_chunk_size + 2) with self.assertRaises(ValueError): # size too large ht.unfold(x, 0, n + 1, 1) - ht.unfold(x, 0, min_chunk_size, min_chunk_size // 2) + ht.unfold(x, 0, min_chunk_size, min_chunk_size + 1) # no fully local unfolds on some nodes # 2D sliding views n = 100 From b323da8b058ccc23f01d659e86030eb52f07f2ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Apr 2024 11:04:34 +0000 Subject: [PATCH 13/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- heat/core/tests/test_manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 204f90493a..6cd05ee984 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3772,7 +3772,7 @@ def test_unfold(self): ht.unfold(x, 0, min_chunk_size + 2) with self.assertRaises(ValueError): # size too large ht.unfold(x, 0, n + 1, 1) - 
From 6c7ad848ecf9e62d306f1210571722206ebd721c Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Sat, 1 Jun 2024 12:11:30 +0200
Subject: [PATCH 15/30] added tests with different datatypes

---
 heat/core/tests/test_manipulations.py | 107 ++++++++++++++------------
 1 file changed, 56 insertions(+), 51 deletions(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 6cd05ee984..b4b3fe2d1f 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3753,55 +3753,60 @@ def test_vstack(self):
         self.assertEqual(res.shape, (2, 12))
 
     def test_unfold(self):
-        # exceptions
-        n = 1000
-        x = ht.arange(n)
-        with self.assertRaises(ValueError):
-            ht.unfold(x, -1, 1, 1)
-        with self.assertRaises(ValueError):
-            ht.unfold(x, 0, 0, 1)
-        with self.assertRaises(ValueError):
-            ht.unfold(x, 0, 1, 0)
-        x.resplit_(0)
-        min_chunk_size = x.lshape_map[:, 0].min().item()
-        if min_chunk_size + 2 > n:  # size too large
-            with self.assertRaises(ValueError):
-                ht.unfold(x, 0, min_chunk_size + 2)
-        else:  # size too large for chunk_size
-            with self.assertRaises(RuntimeError):
-                ht.unfold(x, 0, min_chunk_size + 2)
-        with self.assertRaises(ValueError):  # size too large
-            ht.unfold(x, 0, n + 1, 1)
-        ht.unfold(x, 0, min_chunk_size, min_chunk_size + 1)  # no fully local unfolds on some nodes
-
-        # 2D sliding views
-        n = 100
-
-        x = torch.arange(n * n).reshape((n, n))
-        y = ht.array(x)
-        y.resplit_(0)
-
-        u = x.unfold(0, 3, 3)
-        u = u.unfold(1, 3, 3)
-        u = ht.array(u)
-        v = ht.unfold(y, 0, 3, 3)
-        v = ht.unfold(v, 1, 3, 3)
-
-        self.assertTrue(ht.equal(u, v))
-
-        # more dimensions, different split axes
-        n = 10
-        k = 5  # number of dimensions
-        shape = k * (n,)
-        size = n**k
-
-        x = torch.arange(size).reshape(shape)
-        y = ht.array(x)
-
-        for split in (None, *range(k)):
-            y.resplit_(split)
-            for dimension in range(k):
-                u = ht.array(x.unfold(dimension, 1, 1))
-                v = ht.unfold(y, dimension, 1, 1)
-
-                self.assertTrue(ht.equal(u, v))
+        dtypes = (ht.int, ht.float)
+
+        for dtype in dtypes:  # test with different datatypes
+            # exceptions
+            n = 1000
+            x = ht.arange(n, dtype=dtype)
+            with self.assertRaises(ValueError):
+                ht.unfold(x, -1, 1, 1)
+            with self.assertRaises(ValueError):
+                ht.unfold(x, 0, 0, 1)
+            with self.assertRaises(ValueError):
+                ht.unfold(x, 0, 1, 0)
+            x.resplit_(0)
+            min_chunk_size = x.lshape_map[:, 0].min().item()
+            if min_chunk_size + 2 > n:  # size too large
+                with self.assertRaises(ValueError):
+                    ht.unfold(x, 0, min_chunk_size + 2)
+            else:  # size too large for chunk_size
+                with self.assertRaises(RuntimeError):
+                    ht.unfold(x, 0, min_chunk_size + 2)
+            with self.assertRaises(ValueError):  # size too large
+                ht.unfold(x, 0, n + 1, 1)
+            ht.unfold(
+                x, 0, min_chunk_size, min_chunk_size + 1
+            )  # no fully local unfolds on some nodes
+
+            # 2D sliding views
+            n = 100
+
+            x = torch.arange(n * n).reshape((n, n))
+            y = ht.array(x, dtype)
+            y.resplit_(0)
+
+            u = x.unfold(0, 3, 3)
+            u = u.unfold(1, 3, 3)
+            u = ht.array(u)
+            v = ht.unfold(y, 0, 3, 3)
+            v = ht.unfold(v, 1, 3, 3)
+
+            self.assertTrue(ht.equal(u, v))
+
+            # more dimensions, different split axes
+            n = 10
+            k = 5  # number of dimensions
+            shape = k * (n,)
+            size = n**k
+
+            x = torch.arange(size).reshape(shape)
+            y = ht.array(x, dtype)
+
+            for split in (None, *range(k)):
+                y.resplit_(split)
+                for dimension in range(k):
+                    u = ht.array(x.unfold(dimension, 1, 1))
+                    v = ht.unfold(y, dimension, 1, 1)
+
+                    self.assertTrue(ht.equal(u, v))
From 9ad8d69e69a0f2559c9d66492d3870d37f6067b9 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 4 Jun 2024 11:26:58 +0200
Subject: [PATCH 16/30] renamed ´dimension´ to ´axis´
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 heat/core/manipulations.py | 61 +++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 31 deletions(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 167dc8b98b..92f44e946a 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4181,9 +4181,9 @@ def mpi_topk(a, b, mpi_type):
 MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)
 
 
-def unfold(a: DNDarray, dimension: int, size: int, step: int = 1):
+def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
     """
-    Returns a DNDarray which contains all slices of size size in the dimension dimension.
+    Returns a DNDarray which contains all slices of size size in the axis axis.
 
     Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
 
@@ -4191,8 +4191,8 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
     ----------
     a : DNDarray
         array to unfold
-    dimension : int
-        dimension in which unfolding happens
+    axis : int
+        axis in which unfolding happens
     size : int
         the size of each slice that is unfolded
     step : int
@@ -4200,76 +4200,75 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
 
     Note
     ---------
-    You have to make sure that every node has a chunk of size at least size - 1 if the split dimension of the array is the unfold dimension.
+    You have to make sure that every node has a chunk of size at least size - 1 if the split axis of the array is the unfold axis.
     """
     if step < 1:
         raise ValueError("step must be >= 1.")
     if size < 1:
         raise ValueError("size must be >= 1.")
-    if dimension < 0 or dimension >= a.ndim:
-        raise ValueError(f"{dimension} is not a valid dimension of the given DNDarray.")
-    if size > a.shape[dimension]:  # size too large, runtime error or value error?
+    if axis < 0 or axis >= a.ndim:
+        raise ValueError(f"{axis} is not a valid axis of the given DNDarray.")
+    if size > a.shape[axis]:  # size too large, runtime error or value error?
         raise ValueError(
-            f"maximum size for DNDarray at dimension {dimension} is {a.shape[dimension]} but size is {size}."
+            f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
        )
 
     comm = a.comm
     dev = a.device
     tdev = dev.torch_device
 
-    if a.split is None or comm.size == 1 or a.split != dimension:  # early out
+    if a.split is None or comm.size == 1 or a.split != axis:  # early out
         ret = factories.array(
-            a.larray.unfold(dimension, size, step), is_split=a.split, device=dev, comm=comm
+            a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
         )
 
         return ret
     else:  # comm.size > 1 and split axis == unfold axis
         # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
         # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
-        # ret_shape = (*a_shape[:dimension], int((a_shape[dimension]-size)/step) + 1, a_shape[dimension+1:], size)
+        # ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, a_shape[axis+1:], size)
 
-        if (size - 1 > a.lshape_map[:, dimension]).any():
+        if (size - 1 > a.lshape_map[:, axis]).any():
             raise RuntimeError("Chunk-size needs to be at least size - 1.")
         a.get_halo(size - 1)
         a_lshapes_cum = torch.hstack(
             [
                 torch.zeros(1, dtype=torch.int32, device=tdev),
-                torch.cumsum(a.lshape_map[:, dimension], 0),
+                torch.cumsum(a.lshape_map[:, axis], 0),
             ]
         )
-        # min local index in unfold dimension
+        # min local index in unfold axis
         min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
         loc_unfold_shape = list(a.lshape)
-        loc_unfold_shape[dimension] -= min_index
-        if loc_unfold_shape[dimension] >= size:  # some unfold arrays are unfolds of the local array
+        loc_unfold_shape[axis] -= min_index
+        if loc_unfold_shape[axis] >= size:  # some unfold arrays are unfolds of the local array
             unfold_loc = a.larray[
-                dimension * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
-            ].unfold(dimension, size, step)
+                axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
+            ].unfold(axis, size, step)
         else:
-            loc_unfold_shape[dimension] = 0
+            loc_unfold_shape[axis] = 0
             unfold_loc = torch.zeros((*loc_unfold_shape, size), device=tdev)
         ret_larray = unfold_loc
         if (
-            comm.rank < comm.size - 1 and min_index < a.lshape[dimension]
+            comm.rank < comm.size - 1 and min_index < a.lshape[axis]
         ):  # possibly unfold with halo from next rank
-            max_index = a.lshape[dimension] - min_index - 1
-            max_index = max_index // step * step + min_index  # max local index in unfold dimension
-            rem = max_index + size - a.lshape[dimension]
+            max_index = a.lshape[axis] - min_index - 1
+            max_index = max_index // step * step + min_index  # max local index in unfold axis
+            rem = max_index + size - a.lshape[axis]
             if rem > 0:  # need data from halo
                 unfold_halo = torch.cat(
                     (
                         a.larray[
-                            dimension * (slice(None, None, None),)
+                            axis * (slice(None, None, None),)
                             + (slice(max_index, None, None), Ellipsis)
                         ],
                         a.halo_next[
-                            dimension * (slice(None, None, None),)
-                            + (slice(None, rem, None), Ellipsis)
+                            axis * (slice(None, None, None),) + (slice(None, rem, None), Ellipsis)
                         ],
                     ),
-                    dimension,
-                ).unfold(dimension, size, step)
-            ret_larray = torch.cat((unfold_loc, unfold_halo), dimension)
-        ret = factories.array(ret_larray, is_split=dimension, device=dev, comm=comm)
+                    axis,
+                ).unfold(axis, size, step)
+            ret_larray = torch.cat((unfold_loc, unfold_halo), axis)
+        ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)
 
         return ret
From 778eb33bbc71abec29516c062aecb6e32a16f1f7 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 4 Jun 2024 11:33:40 +0200
Subject: [PATCH 17/30] use `DNDarray.counts_displs()´
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 heat/core/manipulations.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 92f44e946a..7ea69cb0d6 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4231,14 +4231,12 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
         if (size - 1 > a.lshape_map[:, axis]).any():
             raise RuntimeError("Chunk-size needs to be at least size - 1.")
         a.get_halo(size - 1)
-        a_lshapes_cum = torch.hstack(
-            [
-                torch.zeros(1, dtype=torch.int32, device=tdev),
-                torch.cumsum(a.lshape_map[:, axis], 0),
-            ]
-        )
+
+        counts, displs = a.counts_displs()
+        displs = torch.tensor(displs, device=tdev)
+
         # min local index in unfold axis
-        min_index = ((a_lshapes_cum[comm.rank] - 1) // step + 1) * step - a_lshapes_cum[comm.rank]
+        min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
From c36d6a835198c38a59d84ecdb3c17154e790d2ab Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 4 Jun 2024 12:01:43 +0200
Subject: [PATCH 18/30] updated docstring

---
 heat/core/manipulations.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 7ea69cb0d6..e907a9b195 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4183,7 +4183,7 @@ def mpi_topk(a, b, mpi_type):
 
 def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
     """
-    Returns a DNDarray which contains all slices of size size in the axis axis.
+    Returns a DNDarray which contains all slices of size `size` in the axis `axis`.
 
     Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
 
@@ -4198,6 +4198,24 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
     step : int
         the step between each slice
 
+    Example:
+    ```
+    >>> x = ht.arange(1., 8)
+    >>> x
+    DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> ht.unfold(x, 0, 2, 1)
+    DNDarray([[1., 2.],
+              [2., 3.],
+              [3., 4.],
+              [4., 5.],
+              [5., 6.],
+              [6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
+    >>> ht.unfold(x, 0, 2, 2)
+    DNDarray([[1., 2.],
+              [3., 4.],
+              [5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
+    ```
+
     Note
     ---------
     You have to make sure that every node has a chunk of size at least size - 1 if the split axis of the array is the unfold axis.
From 5cdd98663c324e229bd170c514bd2b44f7f0829a Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 4 Jun 2024 12:16:07 +0200
Subject: [PATCH 19/30] use sanitize_axis

---
 heat/core/manipulations.py            | 5 ++---
 heat/core/tests/test_manipulations.py | 2 --
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index e907a9b195..6fb93f62cb 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4224,9 +4224,8 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
         raise ValueError("step must be >= 1.")
     if size < 1:
         raise ValueError("size must be >= 1.")
-    if axis < 0 or axis >= a.ndim:
-        raise ValueError(f"{axis} is not a valid axis of the given DNDarray.")
-    if size > a.shape[axis]:  # size too large, runtime error or value error?
+    axis = stride_tricks.sanitize_axis(a.shape, axis)
+    if size > a.shape[axis]:
         raise ValueError(
             f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
         )
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index b4b3fe2d1f..a6dbac036f 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3759,8 +3759,6 @@ def test_unfold(self):
             # exceptions
             n = 1000
             x = ht.arange(n, dtype=dtype)
-            with self.assertRaises(ValueError):
-                ht.unfold(x, -1, 1, 1)
             with self.assertRaises(ValueError):
                 ht.unfold(x, 0, 0, 1)
""" if not isinstance(halo_size, int): raise TypeError( @@ -433,25 +437,29 @@ def get_halo(self, halo_size: int) -> torch.Tensor: req_list = [] # exchange data with next populated process - if rank != last_rank: - self.comm.Isend(a_next, next_rank) - res_prev = torch.zeros( - a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device - ) - req_list.append(self.comm.Irecv(res_prev, source=next_rank)) + if prev: + if rank != last_rank: + self.comm.Isend(a_next, next_rank) + if rank != first_rank: + res_prev = torch.zeros( + a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device + ) + req_list.append(self.comm.Irecv(res_prev, source=prev_rank)) - if rank != first_rank: - self.comm.Isend(a_prev, prev_rank) - res_next = torch.zeros( - a_next.size(), dtype=a_next.dtype, device=self.device.torch_device - ) - req_list.append(self.comm.Irecv(res_next, source=prev_rank)) + if next: + if rank != first_rank: + self.comm.Isend(a_prev, prev_rank) + if rank != last_rank: + res_next = torch.zeros( + a_next.size(), dtype=a_next.dtype, device=self.device.torch_device + ) + req_list.append(self.comm.Irecv(res_next, source=next_rank)) for req in req_list: req.Wait() - self.__halo_next = res_prev - self.__halo_prev = res_next + self.__halo_next = res_next + self.__halo_prev = res_prev self.__ishalo = True def __cat_halo(self) -> torch.Tensor: From 04d1217b0ec3d75abc95a4a7f4ef77c0aaa409f5 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Wed, 5 Jun 2024 13:36:36 +0200 Subject: [PATCH 21/30] use `DNDarray.array_with_halos` --- heat/core/manipulations.py | 39 ++++++++++---------------------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 6fb93f62cb..afbce869c3 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -4247,43 +4247,24 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1): if (size - 1 > a.lshape_map[:, axis]).any(): raise RuntimeError("Chunk-size needs to be at least size - 1.") - a.get_halo(size - 1) + a.get_halo(size - 1, prev=False) counts, displs = a.counts_displs() displs = torch.tensor(displs, device=tdev) # min local index in unfold axis min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank] - loc_unfold_shape = list(a.lshape) - loc_unfold_shape[axis] -= min_index - if loc_unfold_shape[axis] >= size: # some unfold arrays are unfolds of the local array - unfold_loc = a.larray[ + if min_index >= a.lshape[axis] or ( + comm.rank == comm.size - 1 and min_index + size >= a.lshape[axis] + ): + loc_unfold_shape = list(a.lshape) + loc_unfold_shape[axis] = 0 + ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev) + else: # unfold has local data + ret_larray = a.array_with_halos[ axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis) ].unfold(axis, size, step) - else: - loc_unfold_shape[axis] = 0 - unfold_loc = torch.zeros((*loc_unfold_shape, size), device=tdev) - ret_larray = unfold_loc - if ( - comm.rank < comm.size - 1 and min_index < a.lshape[axis] - ): # possibly unfold with halo from next rank - max_index = a.lshape[axis] - min_index - 1 - max_index = max_index // step * step + min_index # max local index in unfold axis - rem = max_index + size - a.lshape[axis] - if rem > 0: # need data from halo - unfold_halo = torch.cat( - ( - a.larray[ - axis * (slice(None, None, None),) - + (slice(max_index, None, None), Ellipsis) - ], - a.halo_next[ - axis * (slice(None, None, None),) + (slice(None, rem, None), Ellipsis) - ], 
- ), - axis, - ).unfold(axis, size, step) - ret_larray = torch.cat((unfold_loc, unfold_halo), axis) + ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm) return ret From dca31f07ae75c324fc73e1b7ad55c9da3a8a1a2b Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Fri, 5 Jul 2024 16:51:28 +0200 Subject: [PATCH 22/30] fixed condition for empty local unfold data --- heat/core/manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index afbce869c3..3372a79044 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -4255,7 +4255,7 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1): # min local index in unfold axis min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank] if min_index >= a.lshape[axis] or ( - comm.rank == comm.size - 1 and min_index + size >= a.lshape[axis] + comm.rank == comm.size - 1 and min_index + size > a.lshape[axis] ): loc_unfold_shape = list(a.lshape) loc_unfold_shape[axis] = 0 From 82d83eae756efcc55efcd5e8c80ea6551170c1ac Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Fri, 5 Jul 2024 17:16:29 +0200 Subject: [PATCH 23/30] more tests --- heat/core/tests/test_manipulations.py | 33 ++++++++++++++++++++------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index a6dbac036f..867e83128b 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3793,18 +3793,35 @@ def test_unfold(self): self.assertTrue(ht.equal(u, v)) # more dimensions, different split axes - n = 10 - k = 5 # number of dimensions + n = 53 + k = 3 # number of dimensions shape = k * (n,) size = n**k x = torch.arange(size).reshape(shape) - y = ht.array(x, dtype) + _y = x.clone() + y = ht.array(_y, dtype) for split in (None, *range(k)): y.resplit_(split) - for dimension in range(k): - u = ht.array(x.unfold(dimension, 1, 1)) - v = ht.unfold(y, dimension, 1, 1) - - self.assertTrue(ht.equal(u, v)) + for size in range(2, 9): + for step in range(1, 21): + for dimension in range(k): + u = ht.array(x.unfold(dimension, size, step)) + v = ht.unfold(y, dimension, size, step) + + equal = ht.equal(u, v) + if not equal: + diff = (u - v).abs().max().item() + diff_indices = ht.nonzero((u - v).resplit_(None)).resplit_(None) + if y.comm.rank == 0: + print( + f"\ndtype: {dtype}\nsplit: {split}\ndimension: {dimension}\nstep: {step}\nsize: {size}" + ) + # print(f"u.shape: {u_shape}") + # print(f"v.shape: {v_shape}") + print(f"diff: {diff}") + print(f"displs: {y.counts_displs()[1]}") + print(f"diff_indices: {diff_indices}") + + self.assertTrue(ht.equal(u, v)) From 562d9a00b1479f3d317e74398ea64b5acd6c694a Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Fri, 5 Jul 2024 23:37:48 +0200 Subject: [PATCH 24/30] detach after cloning --- heat/core/tests/test_manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 867e83128b..a16409eb87 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3799,7 +3799,7 @@ def test_unfold(self): size = n**k x = torch.arange(size).reshape(shape) - _y = x.clone() + _y = x.clone().detach() y = ht.array(_y, dtype) for split in (None, *range(k)): From 825979ce6f924fc29ada6634cdfc372f2c873b9e Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Fri, 5 Jul 2024 23:40:06 +0200 Subject: [PATCH 
From 82d83eae756efcc55efcd5e8c80ea6551170c1ac Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Fri, 5 Jul 2024 17:16:29 +0200
Subject: [PATCH 23/30] more tests

---
 heat/core/tests/test_manipulations.py | 33 ++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index a6dbac036f..867e83128b 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3793,18 +3793,35 @@ def test_unfold(self):
             self.assertTrue(ht.equal(u, v))
 
             # more dimensions, different split axes
-            n = 10
-            k = 5  # number of dimensions
+            n = 53
+            k = 3  # number of dimensions
             shape = k * (n,)
             size = n**k
 
             x = torch.arange(size).reshape(shape)
-            y = ht.array(x, dtype)
+            _y = x.clone()
+            y = ht.array(_y, dtype)
 
             for split in (None, *range(k)):
                 y.resplit_(split)
-                for dimension in range(k):
-                    u = ht.array(x.unfold(dimension, 1, 1))
-                    v = ht.unfold(y, dimension, 1, 1)
-
-                    self.assertTrue(ht.equal(u, v))
+                for size in range(2, 9):
+                    for step in range(1, 21):
+                        for dimension in range(k):
+                            u = ht.array(x.unfold(dimension, size, step))
+                            v = ht.unfold(y, dimension, size, step)
+
+                            equal = ht.equal(u, v)
+                            if not equal:
+                                diff = (u - v).abs().max().item()
+                                diff_indices = ht.nonzero((u - v).resplit_(None)).resplit_(None)
+                                if y.comm.rank == 0:
+                                    print(
+                                        f"\ndtype: {dtype}\nsplit: {split}\ndimension: {dimension}\nstep: {step}\nsize: {size}"
+                                    )
+                                    # print(f"u.shape: {u_shape}")
+                                    # print(f"v.shape: {v_shape}")
+                                    print(f"diff: {diff}")
+                                    print(f"displs: {y.counts_displs()[1]}")
+                                    print(f"diff_indices: {diff_indices}")
+
+                            self.assertTrue(ht.equal(u, v))

From 562d9a00b1479f3d317e74398ea64b5acd6c694a Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Fri, 5 Jul 2024 23:37:48 +0200
Subject: [PATCH 24/30] detach after cloning

---
 heat/core/tests/test_manipulations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 867e83128b..a16409eb87 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3799,7 +3799,7 @@ def test_unfold(self):
             x = torch.arange(size).reshape(shape)
-            _y = x.clone()
+            _y = x.clone().detach()
             y = ht.array(_y, dtype)

From 825979ce6f924fc29ada6634cdfc372f2c873b9e Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Fri, 5 Jul 2024 23:40:06 +0200
Subject: [PATCH 25/30] test: blocking send in get_halo()

---
 heat/core/dndarray.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index bffc61c364..5300b1bfa3 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -448,7 +448,7 @@ def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torc
 
             if next:
                 if rank != first_rank:
-                    self.comm.Isend(a_prev, prev_rank)
+                    self.comm.Send(a_prev, prev_rank)
                 if rank != last_rank:
                     res_next = torch.zeros(
                         a_next.size(), dtype=a_next.dtype, device=self.device.torch_device

From 3d892d1927b4a59d1a8e23ec207261aca3434b0c Mon Sep 17 00:00:00 2001
From: Hoppe
Date: Mon, 15 Jul 2024 12:17:36 +0200
Subject: [PATCH 26/30] replaced Send by Isend in "next"

---
 heat/core/dndarray.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 5300b1bfa3..e1dc38390c 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -448,7 +448,8 @@ def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torc
 
             if next:
                 if rank != first_rank:
-                    self.comm.Send(a_prev, prev_rank)
+                    # self.comm.Send(a_prev, prev_rank)
+                    req_list.append(self.comm.Isend(a_prev, prev_rank))
                 if rank != last_rank:
                     res_next = torch.zeros(

From 6d92fb9164c642b277ae3d34147a0ecdd14ee972 Mon Sep 17 00:00:00 2001
From: Hoppe
Date: Tue, 23 Jul 2024 14:06:56 +0200
Subject: [PATCH 27/30] int64 in batchparallel clustering predict

---
 heat/cluster/batchparallelclustering.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/cluster/batchparallelclustering.py b/heat/cluster/batchparallelclustering.py
index e935aa85d7..db250d2a9c 100644
--- a/heat/cluster/batchparallelclustering.py
+++ b/heat/cluster/batchparallelclustering.py
@@ -293,7 +293,7 @@ def predict(self, x: DNDarray):
         labels = DNDarray(
             local_labels,
             gshape=(x.shape[0], 1),
-            dtype=ht.int32,
+            dtype=ht.int64,
             device=x.device,
             comm=x.comm,
             split=x.split,
From 76efe78e5eb631f9b04d08a2768a0c145ad462b6 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 23 Jul 2024 15:01:26 +0200
Subject: [PATCH 28/30] added error for size=1

---
 heat/core/manipulations.py            | 8 ++++----
 heat/core/tests/test_manipulations.py | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index cf414e807f..5985df65e3 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -4229,9 +4229,9 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
     axis : int
         axis in which unfolding happens
     size : int
-        the size of each slice that is unfolded
+        the size of each slice that is unfolded, must be greater than 1
     step : int
-        the step between each slice
+        the step between each slice, must be at least 1
 
     Example:
     ```
@@ -4257,8 +4257,8 @@ def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
     """
     if step < 1:
         raise ValueError("step must be >= 1.")
-    if size < 1:
-        raise ValueError("size must be >= 1.")
+    if size <= 1:
+        raise ValueError("size must be > 1.")
     axis = stride_tricks.sanitize_axis(a.shape, axis)
     if size > a.shape[axis]:
         raise ValueError(
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index a16409eb87..0a69bce396 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3759,10 +3759,10 @@ def test_unfold(self):
             # exceptions
             n = 1000
             x = ht.arange(n, dtype=dtype)
-            with self.assertRaises(ValueError):
-                ht.unfold(x, 0, 0, 1)
-            with self.assertRaises(ValueError):
-                ht.unfold(x, 0, 1, 0)
+            with self.assertRaises(ValueError):  # size too small
+                ht.unfold(x, 0, 1, 1)
+            with self.assertRaises(ValueError):  # step too small
+                ht.unfold(x, 0, 2, 0)
             x.resplit_(0)
             min_chunk_size = x.lshape_map[:, 0].min().item()
             if min_chunk_size + 2 > n:  # size too large

From e6ef047f3fef7f134690b8c8bd683c9365f0a4ef Mon Sep 17 00:00:00 2001
From: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com>
Date: Tue, 23 Jul 2024 16:01:24 +0200
Subject: [PATCH 29/30] Update batchparallelclustering.py

Undid my stupid change before that belongs to another issue
---
 heat/cluster/batchparallelclustering.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/cluster/batchparallelclustering.py b/heat/cluster/batchparallelclustering.py
index db250d2a9c..e935aa85d7 100644
--- a/heat/cluster/batchparallelclustering.py
+++ b/heat/cluster/batchparallelclustering.py
@@ -293,7 +293,7 @@ def predict(self, x: DNDarray):
         labels = DNDarray(
             local_labels,
             gshape=(x.shape[0], 1),
-            dtype=ht.int64,
+            dtype=ht.int32,
             device=x.device,
             comm=x.comm,
             split=x.split,

From d00569a367b764468e66283e2a56ad0c639e49c2 Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Sun, 18 Aug 2024 00:25:42 +0200
Subject: [PATCH 30/30] Removed old/dead code, resolved review

---
 heat/core/dndarray.py                 |  1 -
 heat/core/tests/test_manipulations.py | 14 --------------
 2 files changed, 15 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index e1dc38390c..9d9bda1037 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -448,7 +448,6 @@ def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torc
 
             if next:
                 if rank != first_rank:
-                    # self.comm.Send(a_prev, prev_rank)
                     req_list.append(self.comm.Isend(a_prev, prev_rank))
                 if rank != last_rank:
                     res_next = torch.zeros(
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 0571c3a115..554293fa25 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3811,18 +3811,4 @@ def test_unfold(self):
                             u = ht.array(x.unfold(dimension, size, step))
                             v = ht.unfold(y, dimension, size, step)
 
-                            equal = ht.equal(u, v)
-                            if not equal:
-                                diff = (u - v).abs().max().item()
-                                diff_indices = ht.nonzero((u - v).resplit_(None)).resplit_(None)
-                                if y.comm.rank == 0:
-                                    print(
-                                        f"\ndtype: {dtype}\nsplit: {split}\ndimension: {dimension}\nstep: {step}\nsize: {size}"
-                                    )
-                                    # print(f"u.shape: {u_shape}")
-                                    # print(f"v.shape: {v_shape}")
-                                    print(f"diff: {diff}")
-                                    print(f"displs: {y.counts_displs()[1]}")
-                                    print(f"diff_indices: {diff_indices}")
-
                             self.assertTrue(ht.equal(u, v))
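[Editor's note — not part of the patch series] End-to-end usage of the merged feature, mirroring the docstring example; run under MPI (e.g. mpirun -n 2):

    import heat as ht

    x = ht.arange(1.0, 8, split=0)   # distributed along axis 0
    w = ht.unfold(x, 0, 2, 1)        # windows [1,2], [2,3], ..., [6,7]
    print(w)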