Commit
fix the scatter when input is cpu (#1621)
* fix the scatter when input is cpu

* Update _functions.py

Add spaces to comply with the code style

* add unittests

* add a blank line

* fix unittest

Co-authored-by: zhouzaida <[email protected]>
Bilibilee and zhouzaida authored Jan 24, 2022
1 parent b8d7833 commit 0448fcf
Showing 2 changed files with 84 additions and 5 deletions.
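
In short: when the input lived on the CPU, scatter used to unsqueeze it so that its shape mimicked GPU-scattered outputs, which gave CPU inputs a spurious leading dimension. This commit returns CPU tensors unchanged and instead has Scatter.forward normalize its result into a tuple. Below is a minimal sketch of the behavior after the fix (CPU only; it assumes an mmcv build that contains this commit):

# Sketch only: exercises the fixed CPU path of mmcv.parallel._functions.
import torch
from mmcv.parallel._functions import Scatter, scatter

x = torch.zeros([1, 3, 3, 3])

# scatter now returns a CPU tensor as-is instead of unsqueezing it.
out = scatter(input=x, devices=[-1])
assert out.shape == x.shape  # before the fix this was [1, 1, 3, 3, 3]

# Scatter.forward wraps a single-tensor result into a 1-tuple.
outs = Scatter.forward([-1], x)
assert isinstance(outs, tuple) and torch.allclose(outs[0], x)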
7 changes: 2 additions & 5 deletions mmcv/parallel/_functions.py
@@ -22,10 +22,7 @@ def scatter(input, devices, streams=None):
         if devices != [-1]:
             with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                 output = output.cuda(devices[0], non_blocking=True)
-        else:
-            # unsqueeze the first dimension thus the tensor's shape is the
-            # same as those scattered with GPU.
-            output = output.unsqueeze(0)
+
         return output
     else:
         raise Exception(f'Unknown type {type(input)}.')
@@ -76,4 +73,4 @@ def forward(target_gpus, input):
         if streams is not None:
             synchronize_stream(outputs, target_gpus, streams)
 
-        return tuple(outputs)
+        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
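
The guard on the new return line is needed because scatter returns a bare tensor, not a list, when its input is a single tensor, and tuple() applied to a tensor iterates over the first dimension instead of wrapping the tensor. A small sketch of that pitfall (plain PyTorch, independent of mmcv):

import torch

t = torch.zeros([2, 3])
print(tuple(t))  # (tensor([0., 0., 0.]), tensor([0., 0., 0.])) -- splits along dim 0
print((t, ))     # a 1-tuple holding the whole tensor, which is what forward must return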
82 changes: 82 additions & 0 deletions tests/test_parallel.py
@@ -1,11 +1,13 @@
 from unittest.mock import MagicMock, patch
 
+import pytest
 import torch
 import torch.nn as nn
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
                            MMDistributedDataParallel, is_module_wrapper)
+from mmcv.parallel._functions import Scatter, get_input_device, scatter
 from mmcv.parallel.distributed_deprecated import \
     MMDistributedDataParallel as DeprecatedMMDDP
 
@@ -64,3 +66,83 @@ def forward(self, *args, **kwargs):
 
     module_wraper = ModuleWrapper(model)
     assert is_module_wrapper(module_wraper)
+
+
+def test_get_input_device():
+    # if the device is CPU, return -1
+    input = torch.zeros([1, 3, 3, 3])
+    assert get_input_device(input) == -1
+    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+    assert get_input_device(inputs) == -1
+
+    # if the device is GPU, return the index of the device
+    if torch.cuda.is_available():
+        input = torch.zeros([1, 3, 3, 3]).cuda()
+        assert get_input_device(input) == 0
+        inputs = [
+            torch.zeros([1, 3, 3, 3]).cuda(),
+            torch.zeros([1, 4, 4, 4]).cuda()
+        ]
+        assert get_input_device(inputs) == 0
+
+    # input should be a tensor or a list of tensors
+    with pytest.raises(Exception):
+        get_input_device(5)
+
+
+def test_scatter():
+    # if the device is CPU, just return the input
+    input = torch.zeros([1, 3, 3, 3])
+    output = scatter(input=input, devices=[-1])
+    assert torch.allclose(input, output)
+
+    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+    outputs = scatter(input=inputs, devices=[-1])
+    for input, output in zip(inputs, outputs):
+        assert torch.allclose(input, output)
+
+    # if the device is GPU, copy the input from CPU to GPU
+    if torch.cuda.is_available():
+        input = torch.zeros([1, 3, 3, 3])
+        output = scatter(input=input, devices=[0])
+        assert torch.allclose(input.cuda(), output)
+
+        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+        outputs = scatter(input=inputs, devices=[0])
+        for input, output in zip(inputs, outputs):
+            assert torch.allclose(input.cuda(), output)
+
+    # input should be a tensor or a list of tensors
+    with pytest.raises(Exception):
+        scatter(5, [-1])
+
+
+def test_Scatter():
+    # if the device is CPU, just return the input
+    target_gpus = [-1]
+    input = torch.zeros([1, 3, 3, 3])
+    outputs = Scatter.forward(target_gpus, input)
+    assert isinstance(outputs, tuple)
+    assert torch.allclose(input, outputs[0])
+
+    target_gpus = [-1]
+    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+    outputs = Scatter.forward(target_gpus, inputs)
+    assert isinstance(outputs, tuple)
+    for input, output in zip(inputs, outputs):
+        assert torch.allclose(input, output)
+
+    # if the device is GPU, copy the input from CPU to GPU
+    if torch.cuda.is_available():
+        target_gpus = [0]
+        input = torch.zeros([1, 3, 3, 3])
+        outputs = Scatter.forward(target_gpus, input)
+        assert isinstance(outputs, tuple)
+        assert torch.allclose(input.cuda(), outputs[0])
+
+        target_gpus = [0]
+        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+        outputs = Scatter.forward(target_gpus, inputs)
+        assert isinstance(outputs, tuple)
+        for input, output in zip(inputs, outputs):
+            assert torch.allclose(input.cuda(), output)
