diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py
index 49ae70cae29..cc888ce0346 100644
--- a/ignite/distributed/auto.py
+++ b/ignite/distributed/auto.py
@@ -1,4 +1,5 @@
 import warnings
+from typing import Any, Callable, Iterator, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -16,7 +17,7 @@
 __all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"]
 
 
-def auto_dataloader(dataset, **kwargs):
+def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDeviceLoader"]:
     """Helper method to create a dataloader adapted for non-distributed and distributed configurations
     (supporting all available backends from :meth:`~ignite.distributed.utils.available_backends()`).
 
@@ -74,7 +75,9 @@ def auto_dataloader(dataset, **kwargs):
 
         if "batch_sampler" not in kwargs:
             if kwargs.get("sampler", None) is not None:
-                sampler = DistributedProxySampler(kwargs["sampler"], num_replicas=world_size, rank=rank)
+                sampler = DistributedProxySampler(
+                    kwargs["sampler"], num_replicas=world_size, rank=rank
+                )  # type: Union[DistributedProxySampler, DistributedSampler, Sampler]
             else:
                 sampler = DistributedSampler(
                     dataset, num_replicas=world_size, rank=rank, shuffle=kwargs.get("shuffle", True)
@@ -115,9 +118,9 @@ def auto_dataloader(dataset, **kwargs):
         except ImportError:
             pass
 
-        sampler = dataloader.sampler
-        dataloader = mp_device_loader_cls(dataloader, idist.device())
-        dataloader.sampler = sampler
+        mp_dataloader = mp_device_loader_cls(dataloader, idist.device())
+        mp_dataloader.sampler = dataloader.sampler  # type: ignore[attr-defined]
+        return mp_dataloader
 
     return dataloader
 
@@ -266,7 +269,7 @@ class DistributedProxySampler(DistributedSampler):
 
     """
 
-    def __init__(self, sampler: Sampler, num_replicas=None, rank=None):
+    def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None) -> None:
 
         if not isinstance(sampler, Sampler):
             raise TypeError("Argument sampler should be instance of torch Sampler, but given: {}".format(type(sampler)))
@@ -274,24 +277,26 @@ def __init__(self, sampler: Sampler, num_replicas=None, rank=None):
         if not hasattr(sampler, "__len__"):
             raise TypeError("Argument sampler should have length")
 
-        super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
+        super(DistributedProxySampler, self).__init__(
+            sampler, num_replicas=num_replicas, rank=rank, shuffle=False  # type: ignore[arg-type]
+        )
         self.sampler = sampler
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         # deterministically shuffle based on epoch
-        torch.manual_seed(self.epoch)
+        torch.manual_seed(self.epoch)  # type: ignore[attr-defined]
 
-        indices = []
-        while len(indices) < self.total_size:
+        indices = []  # type: List
+        while len(indices) < self.total_size:  # type: ignore[attr-defined]
             indices += list(self.sampler)
 
-        if len(indices) > self.total_size:
-            indices = indices[: self.total_size]
+        if len(indices) > self.total_size:  # type: ignore[attr-defined]
+            indices = indices[: self.total_size]  # type: ignore[attr-defined]
 
         # subsample
-        indices = indices[self.rank : self.total_size : self.num_replicas]
-        if len(indices) != self.num_samples:
-            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))
+        indices = indices[self.rank : self.total_size : self.num_replicas]  # type: ignore[attr-defined]
+        if len(indices) != self.num_samples:  # type: ignore[attr-defined]
+            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))  # type: ignore[attr-defined]
 
         return iter(indices)
 
@@ -304,22 +309,22 @@ def __iter__(self):
     class _MpDeviceLoader:
         # https://github.com/pytorch/xla/pull/2117
         # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
-        def __init__(self, loader, device, **kwargs):
+        def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
             self._loader = loader
             self._device = device
             self._parallel_loader_kwargs = kwargs
 
-        def __iter__(self):
+        def __iter__(self) -> Iterator:
             parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs)
             return parallel_loader.per_device_loader(self._device)
 
-        def __len__(self):
+        def __len__(self) -> int:
             return len(self._loader)
 
     class _XLADistributedOptimizer(Optimizer):
-        def __init__(self, optimizer):
-            super(self.__class__, self).__init__(optimizer.param_groups)
+        def __init__(self, optimizer: Optimizer) -> None:
+            super(self.__class__, self).__init__(optimizer.param_groups)  # type: ignore[call-arg]
             self.wrapped_optimizer = optimizer
 
-        def step(self, closure=None):
+        def step(self, closure: Optional[Callable] = None) -> None:
             xm.optimizer_step(self.wrapped_optimizer, barrier=True)
diff --git a/ignite/distributed/comp_models/__init__.py b/ignite/distributed/comp_models/__init__.py
index 3001edcb067..c9227701078 100644
--- a/ignite/distributed/comp_models/__init__.py
+++ b/ignite/distributed/comp_models/__init__.py
@@ -4,7 +4,7 @@
 from ignite.distributed.comp_models.xla import has_xla_support
 
 
-def setup_available_computation_models():
+def setup_available_computation_models():  # type: ignore # inhomogeneous Tuple types are not supported
     models = [
         _SerialModel,
     ]
diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py
index cd3fad63043..f31b074c398 100644
--- a/ignite/distributed/comp_models/base.py
+++ b/ignite/distributed/comp_models/base.py
@@ -1,7 +1,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from numbers import Number
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union, cast
 
 import torch
 
@@ -13,7 +13,7 @@ class ComputationModel(metaclass=ABCMeta):
     """
 
     # this is an additional local rank storage used when idist is setup from existing native torch dist context
-    _ext_local_rank = None
+    _ext_local_rank = None  # type: Optional[int]
 
     def __init__(self):
         self._backend = None
@@ -21,7 +21,7 @@ def __init__(self):
         self._nnodes = None
         self._node = None
 
-    def _setup_attrs(self):
+    def _setup_attrs(self) -> None:
         if self._nproc_per_node is None:
             self._nproc_per_node = self._compute_nproc_per_node() if self.get_world_size() > 1 else 1
         if self._nnodes is None:
@@ -66,7 +66,7 @@ def backend(self) -> Optional[str]:
         pass
 
     @abstractmethod
-    def finalize(self):
+    def finalize(self) -> None:
         pass
 
     @staticmethod
@@ -76,15 +76,15 @@ def create_from_context() -> Optional["ComputationModel"]:
         pass
 
     @staticmethod
    @abstractmethod
-    def create_from_backend(backend: str, **kwargs) -> "ComputationModel":
+    def create_from_backend(backend: str, **kwargs: Any) -> "ComputationModel":
         pass
 
     @staticmethod
     @abstractmethod
-    def spawn(*args, **kwargs):
+    def spawn(*args: Any, **kwargs: Any) -> None:
         pass
 
-    _collective_op_dtype = None
+    _collective_op_dtype = None  # type: Any
 
     @staticmethod
     def _encode_str(x: str, device: torch.device) -> torch.Tensor:
@@ -107,7 +107,9 @@ def _decode_str(xs: torch.Tensor) -> List[str]:
         out = [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs]
         return out
 
-    def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args, **kwargs) -> torch.Tensor:
+    def _apply_op(
+        self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args: Any, **kwargs: Any
+    ) -> torch.Tensor:
         out_dtype = None
         tensor_device = None
 
@@ -133,7 +135,7 @@ def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *a
         return tensor
 
     def _collective_op(
-        self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args, **kwargs
+        self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
     ) -> Union[torch.Tensor, Number, List[str]]:
         tensor_to_number = tensor_to_str = False
         device = self.device()
@@ -147,7 +149,7 @@ def _collective_op(
         tensor = self._apply_op(tensor, device, fn, *args, **kwargs)
 
         if tensor_to_number and tensor.numel() == 1:
-            return tensor.item()
+            return cast(Number, tensor.item())
         elif tensor_to_str:
             return self._decode_str(tensor)
         return tensor
@@ -156,7 +158,7 @@ def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Un
         if not isinstance(tensor, (torch.Tensor, Number)):
             raise TypeError("Unhandled input type {}".format(type(tensor)))
 
-        return self._collective_op(tensor, self._do_all_reduce, op)
+        return cast(Union[torch.Tensor, Number], self._collective_op(tensor, self._do_all_reduce, op))
 
     def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
@@ -189,7 +191,7 @@ def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> U
         tensor = self._apply_op(tensor, device, self._do_broadcast, src)
 
         if tensor_to_number:
-            return tensor.item()
+            return cast(Number, tensor.item())
         if tensor_to_str:
             list_str = self._decode_str(tensor)
             return list_str[0]
@@ -208,7 +210,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
 
     @abstractmethod
-    def barrier(self):
+    def barrier(self) -> None:
         pass
 
 
@@ -219,6 +221,9 @@ class _SerialModel(ComputationModel):
     name = "serial"
     available_backends = ()
 
+    def __init__(self, _backend: Optional[str] = None, **kwargs: Any) -> None:
+        super(_SerialModel, self).__init__()
+
     def get_local_rank(self) -> int:
         return 0
 
@@ -242,10 +247,10 @@ def device(self) -> torch.device:
             return torch.device("cuda")
         return torch.device("cpu")
 
-    def backend(self) -> None:
+    def backend(self) -> Optional[str]:
         return None
 
-    def finalize(self):
+    def finalize(self) -> None:
         pass
 
     def _compute_nproc_per_node(self) -> int:
@@ -256,17 +261,17 @@ def create_from_context() -> "_SerialModel":
         return _SerialModel()
 
     @staticmethod
-    def create_from_backend(backend: Optional[str] = None, **kwargs) -> "_SerialModel":
+    def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_SerialModel":
         return _SerialModel()
 
     @staticmethod
-    def spawn(*args, **kwargs):
+    def spawn(*args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("Serial computation model does not implement spawn method")
 
     def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
         return tensor
 
-    def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:
+    def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:  # type: ignore
         return tensor
 
     def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
         return tensor
@@ -281,5 +286,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
         pass
 
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
 
-    def barrier(self):
+    def barrier(self) -> None:
         pass
diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py
index 1bdcb1402ad..1b54f1f42f9 100644
--- a/ignite/distributed/comp_models/horovod.py
+++ b/ignite/distributed/comp_models/horovod.py
@@ -1,5 +1,5 @@
 import os
-from typing import Callable, Mapping, Optional, Tuple
+from typing import Any, Callable, Mapping, Optional, Tuple
 
 import torch
 
@@ -33,7 +33,7 @@ class _HorovodDistModel(ComputationModel):
     available_backends = (HOROVOD,)
 
     @staticmethod
-    def _get_hvd_rank():
+    def _get_hvd_rank() -> int:
         try:
             rank = hvd.rank()
         except ValueError as e:
@@ -48,7 +48,7 @@ def create_from_context() -> Optional["_HorovodDistModel"]:
         return _HorovodDistModel()
 
     @staticmethod
-    def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel":
+    def create_from_backend(backend: str, **kwargs: Any) -> "_HorovodDistModel":
         if backend not in _HorovodDistModel.available_backends:
             raise ValueError("Backend should be one of '{}'".format(_HorovodDistModel.available_backends))
 
@@ -57,7 +57,7 @@ def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel":
             raise RuntimeError("Can not re-initialize Horovod if it is already initialized")
         return _HorovodDistModel(do_init=True, **kwargs)
 
-    def __init__(self, do_init=False, **kwargs):
+    def __init__(self, do_init: bool = False, **kwargs: Any) -> None:
         """This is a private method. Please, use `create_from_backend` or `create_from_context`
         """
         super(_HorovodDistModel, self).__init__()
@@ -73,7 +73,7 @@ def __init__(self, do_init=False, **kwargs):
 
         self._setup_attrs()
 
-    def _compute_nproc_per_node(self):
+    def _compute_nproc_per_node(self) -> int:
         return hvd.local_size()
 
     def get_local_rank(self) -> int:
@@ -103,11 +103,11 @@ def device(self) -> torch.device:
     def backend(self) -> str:
         return self._backend
 
-    def finalize(self):
+    def finalize(self) -> None:
         hvd.shutdown()
 
     @staticmethod
-    def _dist_worker_task_fn(backend, fn, args, kwargs_dict):
+    def _dist_worker_task_fn(backend: str, fn: Callable, args: Tuple, kwargs_dict: Mapping) -> None:
         from ignite.distributed.utils import _set_model, finalize
 
         model = _HorovodDistModel.create_from_backend(backend)
@@ -116,15 +116,15 @@ def _dist_worker_task_fn(backend, fn, args, kwargs_dict):
         finalize()
 
     @staticmethod
-    def spawn(
+    def spawn(  # type: ignore[override]
         fn: Callable,
         args: Tuple,
         kwargs_dict: Optional[Mapping] = None,
         nproc_per_node: int = 1,
-        hosts=None,
+        hosts: Optional[str] = None,
         backend: str = HOROVOD,
-        **kwargs
-    ):
+        **kwargs: Any
+    ) -> None:
         c1 = "nnodes" in kwargs and kwargs["nnodes"] > 1
         c2 = "node_rank" in kwargs and kwargs["node_rank"] > 0
         if c1 or c2:
@@ -166,7 +166,7 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return hvd.broadcast(tensor, root_rank=src)
 
-    def barrier(self):
+    def barrier(self) -> None:
         # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
         # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
         hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index be44b8211b5..e64ddeb7a69 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -2,7 +2,7 @@
 import subprocess
 import warnings
 from distutils.version import LooseVersion
-from typing import Callable, Mapping, Optional, Tuple
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -48,22 +48,22 @@ def create_from_context() -> Optional["_NativeDistModel"]:
         return _NativeDistModel()
 
     @staticmethod
-    def create_from_backend(backend: str, **kwargs) -> "_NativeDistModel":
+    def create_from_backend(backend: str, **kwargs: Any) -> "_NativeDistModel":
         if dist.is_available() and dist.is_initialized():
             raise RuntimeError("Can not create new distributed process group if default one is already initialized")
         return _NativeDistModel(backend=backend, **kwargs)
 
-    def __init__(self, backend=None, timeout=None, **kwargs):
+    def __init__(self, backend: Optional[str] = None, timeout: Optional[int] = None, **kwargs: Any) -> None:
         """This is a private method. Please, use `create_from_backend` or `create_from_context`
         """
         super(_NativeDistModel, self).__init__()
-        self._env_backup = None
+        self._env_backup = None  # type: Optional[Dict[str, str]]
         if backend is not None:
             self._create_from_backend(backend, timeout=timeout, **kwargs)
         else:
             self._init_from_context()
 
-    def _create_from_backend(self, backend, timeout=None, **kwargs):
+    def _create_from_backend(self, backend: str, timeout: Optional[int] = None, **kwargs: Any) -> None:
         if backend == dist.Backend.NCCL and not torch.cuda.is_available():
             raise RuntimeError("Nccl backend is required but no cuda capable devices")
 
@@ -71,8 +71,8 @@ def _create_from_backend(self, backend, timeout=None, **kwargs):
 
         self._local_rank = int(os.environ["LOCAL_RANK"])
         # for debug purposes
-        self._master_port = int(os.environ["MASTER_PORT"])
-        self._master_addr = os.environ["MASTER_ADDR"]
+        self._master_port = int(os.environ["MASTER_PORT"])  # type: Optional[int]
+        self._master_addr = os.environ["MASTER_ADDR"]  # type: Optional[str]
 
         init_pg_kwargs = {}
         if timeout is not None:
@@ -87,7 +87,7 @@ def _create_from_backend(self, backend, timeout=None, **kwargs):
 
         self._setup_attrs()
 
-    def _init_from_context(self):
+    def _init_from_context(self) -> None:
 
         self._identify_local_rank()
 
@@ -96,39 +96,38 @@ def _init_from_context(self):
         self._master_addr = None
 
         self._setup_attrs()
 
-    def _compute_nproc_per_node(self):
+    def _compute_nproc_per_node(self) -> int:
         tensor = torch.tensor([self.get_local_rank() + 1]).to(self.device())
         dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
-        return tensor.item()
+        return int(tensor.item())
 
-    def _get_all_hostnames(self):
+    def _get_all_hostnames(self) -> List[Tuple[str, ...]]:
         import socket
 
         device = "cpu"
         if self.backend() == dist.Backend.NCCL:
             index = torch.cuda.current_device()
             device = "cuda:{}".format(index)
-        name = socket.gethostname()
-        name = torch.tensor(bytearray(name, "utf-8")).to(device)
+        hostname = socket.gethostname()
+        name = torch.tensor(bytearray(hostname, "utf-8")).to(device)
         padded_t_name = torch.zeros(256, device=device, dtype=torch.long)
         padded_t_name[: len(name)] = name
         out_t_names = [torch.zeros_like(padded_t_name) for _ in range(self.get_world_size())]
         dist.all_gather(out_t_names, padded_t_name)
-        out_t_names = [tuple(t.cpu().tolist()) for t in out_t_names]
-        return out_t_names
+        return [tuple(t.cpu().tolist()) for t in out_t_names]
 
     @staticmethod
-    def _compute_node_and_local_ranks(rank, hostnames):
+    def _compute_node_and_local_ranks(rank: int, hostnames: List[Tuple[str, ...]]) -> Tuple[int, int]:
         from collections import Counter
 
-        c = Counter(hostnames)
+        c = Counter(hostnames)  # type: Counter
         sizes = torch.tensor([0,] + list(c.values()))
         cumsum_sizes = torch.cumsum(sizes, dim=0)
         node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item()
         local_rank = rank - cumsum_sizes[node_rank].item()
-        return local_rank, node_rank
+        return int(local_rank), node_rank
 
-    def _compute_local_rank_via_hostname(self):
+    def _compute_local_rank_via_hostname(self) -> int:
         # get all hostnames
         hostnames = self._get_all_hostnames()
         local_rank, self._node = self._compute_node_and_local_ranks(self.get_rank(), hostnames)
@@ -142,7 +141,7 @@ def _compute_local_rank_via_hostname(self):
         )
         return local_rank
 
-    def _identify_local_rank(self):
+    def _identify_local_rank(self) -> None:
 
         if "SLURM_JOBID" in os.environ:
             os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
@@ -161,7 +160,7 @@ def _identify_local_rank(self):
             # use socket gethostname heuristic to determine number of nodes => local rank
             self._local_rank = self._compute_local_rank_via_hostname()
 
-    def setup_env_vars(self):
+    def setup_env_vars(self) -> None:
 
         self._env_backup = os.environ.copy()
 
@@ -184,7 +183,7 @@ def setup_env_vars(self):
         os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
         os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "15000")
 
-    def _setup_env_in_slurm(self):
+    def _setup_env_in_slurm(self) -> None:
         for k in ["SLURM_PROCID", "SLURM_LOCALID", "SLURM_NTASKS", "SLURM_JOB_NODELIST"]:
             if k not in os.environ:
                 raise RuntimeError("SLURM distributed configuration is missing '{}' in env variables".format(k))
@@ -227,7 +226,7 @@ def device(self) -> torch.device:
     def backend(self) -> str:
         return dist.get_backend()
 
-    def finalize(self):
+    def finalize(self) -> None:
         dist.destroy_process_group()
         # restore backed-up env
         if self._env_backup is not None:
@@ -236,8 +235,18 @@ def finalize(self):
 
     @staticmethod
     def _dist_worker_task_fn(
-        local_rank, backend, fn, args, kw_dict, world_size, nprocs_per_node, node_rank, master_addr, master_port, kw
-    ):
+        local_rank: int,
+        backend: str,
+        fn: Callable,
+        args: Tuple,
+        kw_dict: Mapping,
+        world_size: int,
+        nprocs_per_node: int,
+        node_rank: int,
+        master_addr: str,
+        master_port: str,
+        kw: Any,
+    ) -> None:
         from ignite.distributed.utils import _set_model, finalize
 
         copy_env_vars = os.environ.copy()
@@ -257,7 +266,7 @@ def _dist_worker_task_fn(
         os.environ.update(copy_env_vars)
 
     @staticmethod
-    def spawn(
+    def spawn(  # type: ignore[override]
         fn: Callable,
         args: Tuple,
         kwargs_dict: Optional[Mapping] = None,
@@ -267,8 +276,8 @@ def spawn(
         master_addr: str = "127.0.0.1",
         master_port: int = 2222,
         backend: str = "nccl",
-        **kwargs
-    ):
+        **kwargs: Any
+    ) -> None:
         world_size = nnodes * nproc_per_node
 
         spawn_kwargs = {
@@ -327,5 +336,5 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         dist.broadcast(tensor, src=src)
         return tensor
 
-    def barrier(self):
+    def barrier(self) -> None:
         dist.barrier()
diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py
index 27c38ee959c..533defdb61d 100644
--- a/ignite/distributed/comp_models/xla.py
+++ b/ignite/distributed/comp_models/xla.py
@@ -1,4 +1,4 @@
-from typing import Callable, Mapping, Optional, Tuple
+from typing import Any, Callable, Mapping, Optional, Tuple
 
 import torch
 
@@ -38,10 +38,10 @@ def create_from_context() -> Optional["_XlaDistModel"]:
         return _XlaDistModel()
 
     @staticmethod
-    def create_from_backend(backend: str = XLA_TPU, **kwargs) -> "_XlaDistModel":
+    def create_from_backend(backend: str = XLA_TPU, **kwargs: Any) -> "_XlaDistModel":
         return _XlaDistModel(backend=backend, **kwargs)
 
-    def __init__(self, backend=None, **kwargs):
+    def __init__(self, backend: Optional[str] = None, **kwargs: Any):
         """This is a private method. Please, use `create_from_backend` or `create_from_context`
         """
         super(_XlaDistModel, self).__init__()
@@ -50,17 +50,17 @@ def __init__(self, backend=None, **kwargs):
         else:
             self._init_from_context()
 
-    def _create_from_backend(self, backend, **kwargs):
+    def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
         xm.rendezvous("init")
 
         self._backend = backend
         self._setup_attrs()
 
-    def _init_from_context(self):
+    def _init_from_context(self) -> None:
         self._backend = XLA_TPU
         self._setup_attrs()
 
-    def _compute_nproc_per_node(self):
+    def _compute_nproc_per_node(self) -> int:
         tensor = torch.tensor([self.get_local_rank() + 1.0], dtype=torch.float).to(self.device())
         xm.all_reduce("max", [tensor,])
         return int(tensor.item())
@@ -90,11 +90,13 @@ def device(self) -> torch.device:
     def backend(self) -> str:
         return self._backend
 
-    def finalize(self):
+    def finalize(self) -> None:
         pass
 
     @staticmethod
-    def _dist_worker_task_fn(local_rank, backend, fn, args, kwargs_dict):
+    def _dist_worker_task_fn(
+        local_rank: int, backend: str, fn: Callable, args: Tuple, kwargs_dict: Mapping
+    ) -> None:
         from ignite.distributed.utils import _set_model, finalize
 
         model = _XlaDistModel.create_from_backend(backend)
@@ -103,7 +105,7 @@ def _dist_worker_task_fn(local_rank, backend, fn, args, kwargs_dict):
         finalize()
 
     @staticmethod
-    def spawn(
+    def spawn(  # type: ignore[override]
         fn: Callable,
         args: Tuple,
         kwargs_dict: Optional[Mapping] = None,
@@ -111,8 +113,8 @@ def spawn(
         nnodes: int = 1,
         node_rank: int = 0,
         backend: str = XLA_TPU,
-        **kwargs
-    ):
+        **kwargs: Any
+    ) -> None:
         if "start_method" not in kwargs:
             kwargs["start_method"] = "fork"
 
@@ -155,5 +157,5 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         xm.all_reduce("sum", [tensor,])
         return tensor
 
-    def barrier(self):
+    def barrier(self) -> None:
         xm.rendezvous("barrier")
diff --git a/ignite/distributed/launcher.py b/ignite/distributed/launcher.py
index f170aa7516e..643650fd14f 100644
--- a/ignite/distributed/launcher.py
+++ b/ignite/distributed/launcher.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 from ignite.distributed import utils as idist
 from ignite.utils import setup_logger
@@ -180,9 +180,9 @@ def __init__(
         nnodes: Optional[int] = None,
         node_rank: Optional[int] = None,
         master_addr: Optional[str] = None,
-        master_port: Optional[str] = None,
-        **spawn_kwargs
-    ):
+        master_port: Optional[int] = None,
+        **spawn_kwargs: Any
+    ) -> None:
         if backend is not None:
             if backend not in idist.available_backends():
                 raise ValueError(
@@ -214,7 +214,14 @@ def __init__(
             self.logger.info("- Parameters to spawn processes: \n\t{}".format(msg))
 
     @staticmethod
-    def _setup_spawn_params(nproc_per_node, nnodes, node_rank, master_addr, master_port, **spawn_kwargs):
+    def _setup_spawn_params(
+        nproc_per_node: int,
+        nnodes: Optional[int] = None,
+        node_rank: Optional[int] = None,
+        master_addr: Optional[str] = None,
+        master_port: Optional[int] = None,
+        **spawn_kwargs: Any
+    ) -> Dict:
         if nproc_per_node < 1:
             raise ValueError("Argument nproc_per_node should positive, but given {}".format(nproc_per_node))
         if nnodes is None:
@@ -244,7 +251,7 @@ def _setup_spawn_params(nproc_per_node, nnodes, node_rank, master_addr, master_p
         params.update(spawn_kwargs)
         return {k: v for k, v in params.items() if v is not None}
 
-    def run(self, func: Callable, *args, **kwargs):
+    def run(self, func: Callable, *args: Any, **kwargs: Any) -> None:
         """Execute ``func`` with provided arguments in distributed context.
 
         Example
@@ -266,7 +273,7 @@ def training(local_rank, config, **kwargs):
             **kwargs: keyword arguments of ``func``.
 
         """
-        if self._spawn_params is not None:
+        if self._spawn_params is not None and self.backend is not None:
             self.logger.info("Spawn function '{}' in {} processes".format(func, self._spawn_params["nproc_per_node"]))
             idist.spawn(self.backend, func, args=args, kwargs_dict=kwargs, **self._spawn_params)
         else:
@@ -276,7 +283,7 @@ def training(local_rank, config, **kwargs):
 
         self.logger.info("End of run")
 
-    def __enter__(self):
+    def __enter__(self) -> "Parallel":
         if (self.backend is not None) and self._spawn_params is None:
             idist.initialize(self.backend)
             self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
@@ -284,7 +291,7 @@ def __enter__(self):
 
         return self
 
-    def __exit__(self, *args, **kwargs):
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
         if (self.backend is not None) and self._spawn_params is None:
             self.logger.info("Finalized processing group with backend: '{}'".format(self.backend))
             idist.finalize()
diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index 60481d03952..ec17dbe9a29 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -1,7 +1,7 @@
 import socket
 from functools import wraps
 from numbers import Number
-from typing import Callable, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
 
 import torch
 
@@ -48,7 +48,7 @@
 _need_to_sync = True
 
 
-def sync(temporary=False):
+def sync(temporary: bool = False) -> None:
     """Helper method to force this module to synchronize with current distributed context.
     This method should be used when distributed context is manually created or destroyed.
 
@@ -102,10 +102,10 @@ def backend() -> Optional[str]:
     return _model.backend()
 
 
-def available_backends() -> Tuple[str]:
+def available_backends() -> Tuple[str, ...]:
     """Returns available backends.
     """
-    out = ()
+    out = ()  # type: Tuple[str, ...]
     for m in registered_computation_models:
         out += m.available_backends
     return out
@@ -190,8 +190,13 @@ def hostname() -> str:
 
 
 def spawn(
-    backend: str, fn: Callable, args: Tuple, kwargs_dict: Optional[Mapping] = None, nproc_per_node: int = 1, **kwargs
-):
+    backend: str,
+    fn: Callable,
+    args: Tuple,
+    kwargs_dict: Optional[Mapping] = None,
+    nproc_per_node: int = 1,
+    **kwargs: Any
+) -> None:
     """Spawns ``nproc_per_node`` processes that run ``fn`` with ``args``/``kwargs_dict`` and initialize
     distributed configuration defined by ``backend``.
 
@@ -344,7 +349,7 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor,
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
-    return _model.all_gather(tensor)
+    return _model.all_gather(tensor)  # type: ignore[arg-type]
 
 
 def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
@@ -390,7 +395,7 @@ def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[t
     return _model.broadcast(tensor, src=src)
 
 
-def barrier():
+def barrier() -> None:
     """Helper method to synchronize all processes.
     """
     if _need_to_sync and isinstance(_model, _SerialModel):
@@ -399,7 +404,7 @@ def barrier():
     _model.barrier()
 
 
-def set_local_rank(index: int):
+def set_local_rank(index: int) -> None:
     """Method to hint the local rank in case if torch native distributed context is created by user
     without using :meth:`~ignite.distributed.initialize` or :meth:`~ignite.distributed.spawn`.
 
@@ -427,7 +432,7 @@ def run(local_rank, *args, **kwargs):
     ComputationModel._ext_local_rank = index
 
 
-def _set_model(model, temporary=False):
+def _set_model(model: Any, temporary: bool = False) -> None:
     global _model, _need_to_sync
     _model = model
     _need_to_sync = True
@@ -435,13 +440,13 @@ def _set_model(model, temporary=False):
         _need_to_sync = False
 
 
-def _assert_backend(backend):
+def _assert_backend(backend: str) -> None:
     backends = available_backends()
     if backend not in backends:
         raise ValueError("Backend should be one of '{}'".format(backends))
 
 
-def initialize(backend: str, **kwargs):
+def initialize(backend: str, **kwargs: Any) -> None:
     """Initializes distributed configuration according to provided ``backend``
 
     Examples:
@@ -495,7 +500,7 @@ def train_fn(local_rank, a, b, c):
     _set_model(comp_model_cls(backend, **kwargs))
 
 
-def finalize():
+def finalize() -> None:
     """Finalizes distributed configuration. For example, in case of native pytorch distributed configuration,
     it calls ``dist.destroy_process_group()``.
     """
@@ -503,7 +508,7 @@ def finalize():
     _set_model(_SerialModel())
 
 
-def show_config():
+def show_config() -> None:
     """Helper method to display distributed configuration via ``logging``.
     """
 
@@ -522,7 +527,7 @@ def show_config():
     logger.info("node rank: {}".format(get_node_rank()))
 
 
-def one_rank_only(rank: int = 0, with_barrier: bool = False):
+def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Callable:
     """Decorator to filter handlers wrt a rank number
 
     Args:
@@ -544,9 +549,9 @@ def some_handler(_):
             ...
     """
 
-    def _one_rank_only(func):
+    def _one_rank_only(func: Callable) -> Callable:
         @wraps(func)
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
             ret = None
             if get_rank() == rank:
                 ret = func(*args, **kwargs)
diff --git a/mypy.ini b/mypy.ini
index e0372c8029f..24c1a726254 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -19,9 +19,11 @@ ignore_errors = True
 ignore_errors = True
 
-[mypy-ignite.distributed.*]
-
-ignore_errors = True
+[mypy-horovod.*]
+ignore_missing_imports = True
 
 [mypy-numpy.*]
 
 ignore_missing_imports = True
+
+[mypy-torch_xla.*]
+ignore_missing_imports = True