Features/1243 Refactoring of communication: separate mpi4py wrappers from DNDarrays #1265

Draft · wants to merge 9 commits into base: main
1 change: 1 addition & 0 deletions heat/__init__.py
@@ -19,3 +19,4 @@
from . import spatial
from . import utils
from . import preprocessing
from . import communication_backends
6 changes: 3 additions & 3 deletions heat/cluster/_kcluster.py
@@ -120,7 +120,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
if x.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
- xi.comm.Bcast(xi, root=proc)
+ xi.comm.Bcast(xi.larray, root=proc)
centroids[i, :] = xi

else:
@@ -155,7 +155,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
if x.comm.rank == proc:
idx = sample - displ[proc]
x0 = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
- x0.comm.Bcast(x0, root=proc)
+ x0.comm.Bcast(x0.larray, root=proc)
centroids[0, :] = x0
for i in range(1, self.n_clusters):
distances = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
@@ -179,7 +179,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
if x.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
- xi.comm.Bcast(xi, root=proc)
+ xi.comm.Bcast(xi.larray, root=proc)
centroids[i, :] = xi

else:
2 changes: 1 addition & 1 deletion heat/cluster/kmedoids.py
@@ -108,7 +108,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
if x.comm.rank == proc:
lidx = idx - displ[proc]
closest_point = ht.array(x.lloc[lidx, :], device=x.device, comm=x.comm)
- closest_point.comm.Bcast(closest_point, root=proc)
+ closest_point.comm.Bcast(closest_point.larray, root=proc)
new_cluster_centers[i, :] = closest_point

return new_cluster_centers
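The recurring change in the cluster modules above passes the process-local torch tensor (``.larray``) to ``Bcast`` instead of the ``DNDarray`` wrapper itself, in line with the PR's goal of keeping the mpi4py wrappers unaware of ``DNDarray``. A minimal sketch of the resulting calling pattern, assuming an initialized MPI environment and that ``MPI_WORLD`` is re-exported by the new ``heat.communication_backends`` package (the variable names are illustrative only):

```python
import torch
import heat as ht

comm = ht.communication_backends.MPI_WORLD  # wrapper around a duplicated COMM_WORLD

# every process allocates a local buffer; rank 0 fills it with the payload
local = torch.arange(4, dtype=torch.float32) if comm.rank == 0 else torch.empty(4)

# broadcast the raw torch tensor (what .larray holds), not a DNDarray wrapper
comm.Bcast(local, root=0)
```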
6 changes: 6 additions & 0 deletions heat/communication_backends/__init__.py
@@ -0,0 +1,6 @@
"""
Add the communication_backends functions to the ht.communication_backends namespace
"""

from .communication import *
from .mpi4py4torch import *
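For orientation, a short sketch of how the new namespace is expected to be used once the star imports above re-export the module-level helpers defined in ``communication.py`` below (whether every name ends up in the public namespace depends on the modules' ``__all__``, which is not shown here):

```python
import heat as ht

# query the globally set default communicator (MPI_WORLD unless changed)
comm = ht.communication_backends.get_comm()
print(comm.rank, comm.size)

# switch the global default, e.g. to the per-process self communicator
ht.communication_backends.use_comm(ht.communication_backends.MPI_SELF)
```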
169 changes: 169 additions & 0 deletions heat/communication_backends/communication.py
@@ -0,0 +1,169 @@
"""
Module implementing the communication layer of HeAT
"""
from __future__ import annotations
import torch
from typing import Optional, Tuple
from ..core.stride_tricks import sanitize_axis


class Communication:
"""
Base class for communications (intended for other backends)
"""

@staticmethod
def is_distributed() -> bool:
"""
Whether or not the Communication is distributed
"""
raise NotImplementedError()

def __init__(self) -> None:
raise NotImplementedError()

def chunk(
self,
shape: Tuple[int],
split: int,
rank: Optional[int] = None,
w_size: Optional[int] = None,
sparse: bool = False,
) -> Tuple[int, Tuple[int], Tuple[slice]]:
"""
Calculates the chunk of data that will be assigned to this compute node given a global data shape and a split
axis.
Returns ``(offset, local_shape, slices)``: the offset in the split dimension, the resulting local shape if the
global input shape is chunked on the split axis, and the chunk slices with respect to the given shape.
If ``sparse`` is True, only the start and end index along the split axis are returned.

Parameters
----------
shape : Tuple[int,...]
The global shape of the data to be split
split : int
The axis along which to chunk the data
rank : int, optional
Process for which the chunking is calculated, defaults to ``self.rank``.
Intended for creating chunk maps without communication
w_size : int, optional
The MPI world size, defaults to ``self.size``.
Intended for creating chunk maps without communication
sparse : bool, optional
Specifies whether the array is a sparse matrix
"""
# ensure the split axis is valid; a split of None means the data is not distributed
split = sanitize_axis(shape, split)
if split is None:
return 0, shape, tuple(slice(0, end) for end in shape)
rank = self.rank if rank is None else rank
w_size = self.size if w_size is None else w_size
if not isinstance(rank, int) or not isinstance(w_size, int):
raise TypeError("rank and size must be integers")

dims = len(shape)
size = shape[split]
chunk = size // w_size
remainder = size % w_size

if remainder > rank:
chunk += 1
start = rank * chunk
else:
start = rank * chunk + remainder
end = start + chunk

if sparse:
return start, end

return (
start,
tuple(shape[i] if i != split else end - start for i in range(dims)),
tuple(slice(0, shape[i]) if i != split else slice(start, end) for i in range(dims)),
)

def counts_displs_shape(
self, shape: Tuple[int], axis: int
) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
"""
Calculates the item counts, displacements and output shape for a variable sized all-to-all MPI-call (e.g.
``MPI_Alltoallv``). The passed shape is chunked evenly along the given axis across all nodes.

Parameters
----------
shape : Tuple[int,...]
The shape for which to calculate the chunking.
axis : int
The axis along which the chunking is performed.

"""
# the number of elements sent/received by each node
counts = torch.full((self.size,), shape[axis] // self.size)
counts[: shape[axis] % self.size] += 1

# the displacements into the buffer
displs = torch.zeros((self.size,), dtype=counts.dtype)
torch.cumsum(counts[:-1], out=displs[1:], dim=0)

# the output shape for the receiving buffer, assuming all nodes provide an input of the same size as this node
output_shape = list(shape)
output_shape[axis] = self.size * counts[self.rank].item()

return tuple(counts.tolist()), tuple(displs.tolist()), tuple(output_shape)


# duplicate the MPI communicators so that HeAT's internal communication does not interfere with other users of COMM_WORLD
from .mpi4py4torch import MPICommunication
from mpi4py import MPI

comm = MPI.COMM_WORLD
dup_comm = comm.Dup()

MPI_WORLD = MPICommunication(dup_comm)
MPI_SELF = MPICommunication(MPI.COMM_SELF.Dup())

# set the default communicator to be MPI_WORLD
__default_comm = MPI_WORLD


def get_comm() -> Communication:
"""
Retrieves the currently set global default communicator.
"""
return __default_comm


def sanitize_comm(comm: Optional[Communication]) -> Communication:
"""
Sanitizes a communication object, i.e. checks whether it is already an instance of :class:`Communication`
and returns it unchanged, or falls back to the globally set default communicator if ``None`` is passed.

Parameters
----------
comm : Communication
The comm to be sanitized

Raises
------
TypeError
If the given communication is not the proper type
"""
if comm is None:
return get_comm()
elif isinstance(comm, Communication):
return comm

raise TypeError(f"Unknown communication, must be instance of {Communication}")


def use_comm(comm: Optional[Communication] = None) -> None:
"""
Sets the globally used default communicator.

Parameters
----------
comm : Communication or None
The communicator to be set as the global default; passing ``None`` keeps the current default.
"""
global __default_comm
__default_comm = sanitize_comm(comm)
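To make the chunking helpers above concrete, a small sketch of what ``chunk`` and ``counts_displs_shape`` return for a global shape of ``(10, 3)``; the numbers assume a world size of 4 and that ``MPI_WORLD`` inherits these methods from ``Communication`` (both assumptions are illustrative, not guaranteed by this diff):

```python
from heat.communication_backends import MPI_WORLD  # assumes the name is re-exported

comm = MPI_WORLD  # world size assumed to be 4 for the numbers below

# chunk a global (10, 3) array along axis 0
offset, lshape, slices = comm.chunk((10, 3), split=0)
# ranks 0 and 1 receive 3 rows, ranks 2 and 3 receive 2 rows:
#   rank 0 -> offset 0, lshape (3, 3), slices (slice(0, 3), slice(0, 3))
#   rank 2 -> offset 6, lshape (2, 3), slices (slice(6, 8), slice(0, 3))

# counts/displacements for an Alltoallv-style exchange along axis 0
counts, displs, out_shape = comm.counts_displs_shape((10, 3), axis=0)
# counts    == (3, 3, 2, 2)
# displs    == (0, 3, 6, 8)
# out_shape == (12, 3) on rank 0   (self.size * counts[rank] rows)
```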