Activate mypy in ignite.distributed #1355

Merged: 7 commits, Oct 6, 2020
Changes from 3 commits
ignite/distributed/auto.py: 45 changes (31 additions & 14 deletions)
@@ -1,4 +1,5 @@
import warnings
from typing import Any, Callable, Iterator, List, Optional, Union

import torch
import torch.nn as nn
@@ -16,7 +17,7 @@
__all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"]


def auto_dataloader(dataset, **kwargs):
def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDeviceLoader"]:
"""Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

@@ -74,7 +75,9 @@ def auto_dataloader(dataset, **kwargs):

if "batch_sampler" not in kwargs:
if kwargs.get("sampler", None) is not None:
sampler = DistributedProxySampler(kwargs["sampler"], num_replicas=world_size, rank=rank)
sampler = DistributedProxySampler(
kwargs["sampler"], num_replicas=world_size, rank=rank
) # type: Union[DistributedProxySampler, DistributedSampler]
else:
sampler = DistributedSampler(
dataset, num_replicas=world_size, rank=rank, shuffle=kwargs.get("shuffle", True)
@@ -101,7 +104,7 @@ def auto_dataloader(dataset, **kwargs):
kwargs["pin_memory"] = kwargs.get("pin_memory", "cuda" in idist.device().type)

logger.info("Use data loader kwargs for dataset '{}': \n\t{}".format(repr(dataset)[:20].strip(), kwargs))
dataloader = DataLoader(dataset, **kwargs)
dataloader = DataLoader(dataset, **kwargs) # type: Union[DataLoader, "_MpDeviceLoader"]
Collaborator: Here it is the DataLoader type, without _MpDeviceLoader.

Contributor (author): At this specific line you are right, but a few lines later it can be reassigned to _MpDeviceLoader, which results in the following error:

ignite/distributed/auto.py:122: error: Incompatible types in assignment (expression has type "_MpDeviceLoader", variable has type "DataLoader[Any]")  [assignment]

Collaborator: OK, what if we do something like:

dataloader = DataLoader(dataset, **kwargs)  # type: DataLoader

if idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU and world_size > 1:
    sampler = dataloader.sampler  # type: ignore[union-attr]
    mp_dataloader = mp_device_loader_cls(dataloader, idist.device())  # type: "_MpDeviceLoader"
    mp_dataloader.sampler = sampler  # type: ignore[attr-defined]
    return mp_dataloader

return dataloader

Contributor (author): Yeah, that should work; let me try it.
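
For reference, a minimal, self-contained sketch of the two annotation strategies discussed above. The names (_FakeMpDeviceLoader, make_loader_union, make_loader_separate) are hypothetical and not part of ignite; the wrapper only stands in for torch_xla's MpDeviceLoader.

from typing import Any, Iterator, Union

import torch
from torch.utils.data import DataLoader, Dataset


class _FakeMpDeviceLoader:
    # Illustration-only stand-in for torch_xla's MpDeviceLoader.
    def __init__(self, loader: Any, device: torch.device) -> None:
        self._loader = loader
        self._device = device

    def __iter__(self) -> Iterator[Any]:
        return iter(self._loader)


def make_loader_union(dataset: Dataset, wrap: bool) -> Union[DataLoader, _FakeMpDeviceLoader]:
    # Strategy 1: a Union annotation lets the same name be rebound to the wrapper,
    # silencing the "incompatible types in assignment" error quoted above.
    loader = DataLoader(dataset)  # type: Union[DataLoader, _FakeMpDeviceLoader]
    if wrap:
        loader = _FakeMpDeviceLoader(loader, torch.device("cpu"))
    return loader


def make_loader_separate(dataset: Dataset, wrap: bool) -> Union[DataLoader, _FakeMpDeviceLoader]:
    # Strategy 2: keep `loader` a plain DataLoader and bind the wrapper to a new name,
    # so mypy never sees a reassignment between incompatible types.
    loader = DataLoader(dataset)  # type: DataLoader
    if wrap:
        mp_loader = _FakeMpDeviceLoader(loader, torch.device("cpu"))
        return mp_loader
    return loader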


if idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU and world_size > 1:

@@ -115,7 +118,7 @@ def auto_dataloader(dataset, **kwargs):
except ImportError:
pass

sampler = dataloader.sampler
sampler = dataloader.sampler # type: ignore[union-attr]
dataloader = mp_device_loader_cls(dataloader, idist.device())
dataloader.sampler = sampler

@@ -266,22 +269,30 @@ class DistributedProxySampler(DistributedSampler):

"""

def __init__(self, sampler: Sampler, num_replicas=None, rank=None):
def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None) -> None:

if not isinstance(sampler, Sampler):
raise TypeError("Argument sampler should be instance of torch Sampler, but given: {}".format(type(sampler)))

if not hasattr(sampler, "__len__"):
raise TypeError("Argument sampler should have length")

super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
super(DistributedProxySampler, self).__init__(
sampler, num_replicas=num_replicas, rank=rank, shuffle=False # type: ignore[arg-type]
)
self.sampler = sampler

def __iter__(self):
def __setattr__(self, name: str, value: Any) -> None:
gruebel marked this conversation as resolved.
super().__setattr__(name, value)

def __getattr__(self, name: str) -> Any:
super().__getattribute__(name)

def __iter__(self) -> Iterator:
# deterministically shuffle based on epoch
torch.manual_seed(self.epoch)

indices = []
indices = [] # type: List
while len(indices) < self.total_size:
indices += list(self.sampler)

@@ -304,22 +315,28 @@ def __iter__(self):
class _MpDeviceLoader:
# https://github.com/pytorch/xla/pull/2117
# From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
def __init__(self, loader, device, **kwargs):
def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
self._loader = loader
self._device = device
self._parallel_loader_kwargs = kwargs

def __iter__(self):
def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)

def __getattr__(self, name: str) -> Any:
super().__getattribute__(name)

def __iter__(self) -> Iterator:
parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs)
return parallel_loader.per_device_loader(self._device)

def __len__(self):
def __len__(self) -> int:
return len(self._loader)

class _XLADistributedOptimizer(Optimizer):
def __init__(self, optimizer):
super(self.__class__, self).__init__(optimizer.param_groups)
def __init__(self, optimizer: Optimizer) -> None:
super(self.__class__, self).__init__(optimizer.param_groups, {})
self.wrapped_optimizer = optimizer

def step(self, closure=None):
def step(self, closure: Optional[Callable] = None) -> None:
xm.optimizer_step(self.wrapped_optimizer, barrier=True)
ignite/distributed/comp_models/__init__.py: 2 changes (1 addition & 1 deletion)
@@ -4,7 +4,7 @@
from ignite.distributed.comp_models.xla import has_xla_support


def setup_available_computation_models():
def setup_available_computation_models(): # type: ignore # inhomogeneous Tuple types are not supported
models = [
_SerialModel,
]
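
As a generic aside on the mypy note in the comment above: a tuple can be annotated per position when its length is fixed, while a dynamically built mix of subclasses is commonly typed as a variable-length tuple of the common base. A minimal sketch with hypothetical classes, unrelated to the actual return value of setup_available_computation_models (which is not shown in this hunk):

from typing import Tuple, Type


class Base:
    pass


class A(Base):
    pass


class B(Base):
    pass


def fixed_pair() -> Tuple[Type[A], Type[B]]:
    # Fixed-length tuple: one type per position.
    return (A, B)


def dynamic_mix(include_b: bool) -> Tuple[Type[Base], ...]:
    # Variable-length tuple of the common base covers a mix built at runtime.
    if include_b:
        return (A, B)
    return (A,)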
ignite/distributed/comp_models/base.py: 45 changes (25 additions & 20 deletions)
@@ -1,7 +1,7 @@
import warnings
from abc import ABCMeta, abstractmethod
from numbers import Number
from typing import Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Union, cast

import torch

@@ -13,15 +13,15 @@ class ComputationModel(metaclass=ABCMeta):
"""

# this is an additional local rank storage used when idist is setup from existing native torch dist context
_ext_local_rank = None
_ext_local_rank = None # type: Optional[int]

def __init__(self):
self._backend = None
self._nproc_per_node = None
self._nnodes = None
self._node = None

def _setup_attrs(self):
def _setup_attrs(self) -> None:
if self._nproc_per_node is None:
self._nproc_per_node = self._compute_nproc_per_node() if self.get_world_size() > 1 else 1
if self._nnodes is None:
@@ -66,7 +66,7 @@ def backend(self) -> Optional[str]:
pass

@abstractmethod
def finalize(self):
def finalize(self) -> None:
pass

@staticmethod
@@ -76,15 +76,15 @@ def create_from_context() -> Optional["ComputationModel"]:

@staticmethod
@abstractmethod
def create_from_backend(backend: str, **kwargs) -> "ComputationModel":
def create_from_backend(backend: str, **kwargs: Any) -> "ComputationModel":
pass

@staticmethod
@abstractmethod
def spawn(*args, **kwargs):
def spawn(*args: Any, **kwargs: Any) -> None:
pass

_collective_op_dtype = None
_collective_op_dtype = None # type: Any

@staticmethod
def _encode_str(x: str, device: torch.device) -> torch.Tensor:
@@ -107,7 +107,9 @@ def _decode_str(xs: torch.Tensor) -> List[str]:
out = [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs]
return out

def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args, **kwargs) -> torch.Tensor:
def _apply_op(
self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args: Any, **kwargs: Any
) -> torch.Tensor:
out_dtype = None
tensor_device = None

@@ -133,7 +135,7 @@ def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *a
return tensor

def _collective_op(
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args, **kwargs
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, Number, List[str]]:
tensor_to_number = tensor_to_str = False
device = self.device()
@@ -147,7 +149,7 @@ def _collective_op(
tensor = self._apply_op(tensor, device, fn, *args, **kwargs)

if tensor_to_number and tensor.numel() == 1:
return tensor.item()
return cast(Number, tensor.item())
elif tensor_to_str:
return self._decode_str(tensor)
return tensor
@@ -156,7 +158,7 @@ def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Un
if not isinstance(tensor, (torch.Tensor, Number)):
raise TypeError("Unhandled input type {}".format(type(tensor)))

return self._collective_op(tensor, self._do_all_reduce, op)
return cast(Union[torch.Tensor, Number], self._collective_op(tensor, self._do_all_reduce, op))

def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
if not isinstance(tensor, (torch.Tensor, Number, str)):
@@ -189,7 +191,7 @@ def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> U
tensor = self._apply_op(tensor, device, self._do_broadcast, src)

if tensor_to_number:
return tensor.item()
return cast(Number, tensor.item())
if tensor_to_str:
list_str = self._decode_str(tensor)
return list_str[0]
@@ -208,7 +210,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass

@abstractmethod
def barrier(self):
def barrier(self) -> None:
pass


@@ -219,6 +221,9 @@ class _SerialModel(ComputationModel):
name = "serial"
available_backends = ()

def __init__(self, _backend: Optional[str] = None, **_kwargs: Any) -> None:
super(_SerialModel, self).__init__()

def get_local_rank(self) -> int:
return 0

@@ -242,10 +247,10 @@ def device(self) -> torch.device:
return torch.device("cuda")
return torch.device("cpu")

def backend(self) -> None:
def backend(self) -> Optional[str]:
return None

def finalize(self):
def finalize(self) -> None:
pass

def _compute_nproc_per_node(self) -> int:
@@ -256,18 +261,18 @@ def create_from_context() -> "_SerialModel":
return _SerialModel()

@staticmethod
def create_from_backend(backend: Optional[str] = None, **kwargs) -> "_SerialModel":
def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_SerialModel":
return _SerialModel()

@staticmethod
def spawn(*args, **kwargs):
def spawn(*args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Serial computation model does not implement spawn method")

def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
return tensor

def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:
return tensor
def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
gruebel marked this conversation as resolved.
return cast(Union[torch.Tensor, Number], tensor)

def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
return tensor
@@ -281,5 +286,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass

def barrier(self):
def barrier(self) -> None:
pass
ignite/distributed/comp_models/horovod.py: 24 changes (12 additions & 12 deletions)
@@ -1,5 +1,5 @@
import os
from typing import Callable, Mapping, Optional, Tuple
from typing import Any, Callable, Mapping, Optional, Tuple

import torch

@@ -33,7 +33,7 @@ class _HorovodDistModel(ComputationModel):
available_backends = (HOROVOD,)

@staticmethod
def _get_hvd_rank():
def _get_hvd_rank() -> int:
try:
rank = hvd.rank()
except ValueError as e:
@@ -48,7 +48,7 @@ def create_from_context() -> Optional["_HorovodDistModel"]:
return _HorovodDistModel()

@staticmethod
def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel":
def create_from_backend(backend: str, **kwargs: Any) -> "_HorovodDistModel":
if backend not in _HorovodDistModel.available_backends:
raise ValueError("Backend should be one of '{}'".format(_HorovodDistModel.available_backends))

@@ -57,7 +57,7 @@ def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel":
raise RuntimeError("Can not re-initialize Horovod if it is already initialized")
return _HorovodDistModel(do_init=True, **kwargs)

def __init__(self, do_init=False, **kwargs):
def __init__(self, do_init: bool = False, **kwargs: Any) -> None:
"""This is a private method. Please, use `create_from_backend` or `create_from_context`
"""
super(_HorovodDistModel, self).__init__()
@@ -73,7 +73,7 @@ def __init__(self, do_init=False, **kwargs):

self._setup_attrs()

def _compute_nproc_per_node(self):
def _compute_nproc_per_node(self) -> int:
return hvd.local_size()

def get_local_rank(self) -> int:
@@ -103,11 +103,11 @@ def device(self) -> torch.device:
def backend(self) -> str:
return self._backend

def finalize(self):
def finalize(self) -> None:
hvd.shutdown()

@staticmethod
def _dist_worker_task_fn(backend, fn, args, kwargs_dict):
def _dist_worker_task_fn(backend: str, fn: Callable, args: Tuple, kwargs_dict: Mapping) -> None:
from ignite.distributed.utils import _set_model, finalize

model = _HorovodDistModel.create_from_backend(backend)
@@ -116,15 +116,15 @@ def _dist_worker_task_fn(backend, fn, args, kwargs_dict):
finalize()

@staticmethod
def spawn(
def spawn( # type: ignore[override]
fn: Callable,
args: Tuple,
kwargs_dict: Optional[Mapping] = None,
nproc_per_node: int = 1,
hosts=None,
hosts: Optional[str] = None,
backend: str = HOROVOD,
**kwargs
):
**kwargs: Any
) -> None:
c1 = "nnodes" in kwargs and kwargs["nnodes"] > 1
c2 = "node_rank" in kwargs and kwargs["node_rank"] > 0
if c1 or c2:
@@ -166,7 +166,7 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return hvd.broadcast(tensor, root_rank=src)

def barrier(self):
def barrier(self) -> None:
# https://github.com/horovod/horovod/issues/159#issuecomment-424834603
# hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")