Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce DNDarray.__array__() method #1154

Merged
merged 12 commits into from
May 22, 2023
13 changes: 10 additions & 3 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,12 @@ def __cat_halo(self) -> torch.Tensor:
dim=self.split,
)

def __array__(self) -> np.ndarray:
    """
    Return the process-local slice of this :class:`DNDarray` as a numpy ndarray.

    If the local data already resides on CPU, the result is a view of the
    process-local torch tensor; for GPU-resident data, ``cpu()`` copies the
    local slice to host memory first, so the returned ndarray is a copy.
    """
    local_cpu = self.larray.cpu()
    return local_cpu.__array__()

def astype(self, dtype, copy=True) -> DNDarray:
"""
Returns a casted version of this array.
Expand All @@ -472,7 +478,7 @@ def astype(self, dtype, copy=True) -> DNDarray:
Parameters
----------
dtype : datatype
HeAT type to which the array is cast
Heat type to which the array is cast
copy : bool, optional
By default the operation returns a copy of this array. If copy is set to ``False`` the cast is performed
in-place and this array is returned
Expand Down Expand Up @@ -1111,8 +1117,9 @@ def __len__(self) -> int:

def numpy(self) -> np.array:
"""
Convert :class:`DNDarray` to numpy array. If the ``DNDarray`` is distributed it will be merged beforehand. If the ``DNDarray``
resides on the GPU, it will be copied to the CPU first.
Returns a copy of the :class:`DNDarray` as numpy ndarray. If the ``DNDarray`` resides on the GPU, the underlying data will be copied to the CPU first.

If the ``DNDarray`` is distributed, an MPI Allgather operation will be performed before converting to np.ndarray, i.e. each MPI process will end up holding a copy of the entire array in memory. Make sure process memory is sufficient!

Examples
--------
Expand Down
19 changes: 19 additions & 0 deletions heat/core/tests/test_dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,25 @@ def test_gethalo(self):
self.assertTrue(data.halo_prev is prev_halo or (data.halo_prev == prev_halo).all())
self.assertTrue(data.halo_next is next_halo or (data.halo_next == next_halo).all())

def test_array(self):
    # non-split (replicated) array: __array__ covers the full global shape
    arr = ht.arange(6 * 7 * 8).reshape((6, 7, 8))
    expected = np.arange(6 * 7 * 8, dtype=np.int32).reshape((6, 7, 8))

    as_np = arr.__array__()
    self.assertIsInstance(as_np, np.ndarray)
    self.assertTrue((as_np == expected).all())
    self.assertEqual(as_np.dtype, expected.dtype)
    self.assertEqual(as_np.shape, arr.gshape)

    # split (distributed) array: __array__ exposes only the process-local chunk
    arr = ht.arange(6 * 7 * 8, dtype=ht.float64, split=0).reshape((6, 7, 8))
    expected = np.arange(6 * 7 * 8, dtype=np.float64).reshape((6, 7, 8))

    as_np = arr.__array__()
    self.assertIsInstance(as_np, np.ndarray)
    self.assertTrue((as_np == arr.larray.cpu().numpy()).all())
    self.assertEqual(as_np.dtype, expected.dtype)
    self.assertEqual(as_np.shape, arr.lshape)

def test_larray(self):
# undistributed case
x = ht.arange(6 * 7 * 8).reshape((6, 7, 8))
Expand Down