
Improve typing of distributed.comp_modules.utils.all_gather #1368

Closed
gruebel opened this issue Oct 6, 2020 · 3 comments · Fixed by #1370

Comments

@gruebel
Contributor

gruebel commented Oct 6, 2020

🚀 Feature

As a follow-up to issue #1344 and comment #1355 (comment), the typing of utils.all_gather and its different realizations (serial, native, hvd, xla) needs to be reworked.

@gruebel
Contributor Author

gruebel commented Oct 6, 2020

@vfdev-5 I took another look at the code, and your suggested typing is not feasible:

def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, List[Number], List[str]]:

Looking at the definition of all_gather inside the ComputationModel class

def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
    if not isinstance(tensor, (torch.Tensor, Number, str)):
        raise TypeError("Unhandled input type {}".format(type(tensor)))
    return self._collective_op(tensor, self._do_all_gather)

there is a reference to the function _collective_op, and there is no possibility to get a result of type List[Number]:
def _collective_op(
    self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, Number, List[str]]:
    tensor_to_number = tensor_to_str = False
    device = self.device()
    if isinstance(tensor, Number):
        tensor_to_number = True
        tensor = torch.tensor(tensor, device=device, dtype=self._collective_op_dtype)
    elif isinstance(tensor, str):
        tensor_to_str = True
        tensor = self._encode_str(tensor, device)
    tensor = self._apply_op(tensor, device, fn, *args, **kwargs)
    if tensor_to_number and tensor.numel() == 1:
        return cast(Number, tensor.item())
    elif tensor_to_str:
        return self._decode_str(tensor)
    return tensor

unless I change the return in line 152 to

if tensor_to_number and tensor.numel() == 1:
    return cast(List[Number], [tensor.item()])

If that is ok for you, then I would make the change, adjust all the typings, and create the PR.

@vfdev-5
Collaborator

vfdev-5 commented Oct 6, 2020

@gruebel you are right, the _collective_op implementation does not handle this case properly. Let me detail it: in the case of all_gather, the collective op does the following for a number input:

Number -> Tensor[1] -> Tensor[N] returns Tensor[N]

whereas in the case of all_reduce, the collective op does:

Number -> Tensor[1] -> Tensor[1] -> Number returns Number

Probably, we can add the following case

if tensor_to_number:
    if tensor.numel() == 1:
        return cast(Number, tensor.item())
    else:
        return cast(List[Number], tensor.tolist())

I think the return type annotation of _collective_op is incomplete; I'd say it should be

    def _collective_op(...) -> Union[torch.Tensor, Number, List[Number], List[str]]:

and cover both the all_reduce and all_gather cases.
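A self-contained sketch of the proposed number branch, extracted into a standalone function so it can be exercised directly (the function name is hypothetical; in the real code this logic lives inside _collective_op):

```python
from numbers import Number
from typing import List, Union, cast

import torch


def _number_branch(tensor: torch.Tensor) -> Union[Number, List[Number]]:
    # Mirrors the proposed handling: all_reduce leaves a single element
    # (return a Number), while all_gather leaves N elements
    # (return List[Number] via Tensor.tolist).
    if tensor.numel() == 1:
        return cast(Number, tensor.item())
    return cast(List[Number], tensor.tolist())


print(_number_branch(torch.tensor([3.0])))       # all_reduce-style: 3.0
print(_number_branch(torch.tensor([1.0, 2.0])))  # all_gather-style: [1.0, 2.0]
```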

@gruebel
Contributor Author

gruebel commented Oct 6, 2020

Ok, that makes sense. I also saw the reference to ~Tensor.tolist there. I will give it a try 💪
