Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Update torchvision transform wrapper #1595

Merged
merged 3 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmpretrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.7.1'
mmengine_minimum_version = '0.7.3'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)

Expand Down
44 changes: 12 additions & 32 deletions mmpretrain/datasets/transforms/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numbers
import re
import string
import traceback
from enum import EnumMeta
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -45,43 +44,23 @@ def _interpolation_modes_from_str(t: str):
return inverse_modes_mapping[t]


def _warpper_vision_transform_cls(vision_transform_cls, new_name):
"""build a transform warpper class for specific torchvison.transform to
handle the different input type between torchvison.transforms with
mmcls.datasets.transforms."""
class TorchVisonTransformWrapper:

def new_init(self, *args, **kwargs):
def __init__(self, transform, *args, **kwargs):
if 'interpolation' in kwargs and isinstance(kwargs['interpolation'],
str):
kwargs['interpolation'] = _interpolation_modes_from_str(
kwargs['interpolation'])
if 'dtype' in kwargs and isinstance(kwargs['dtype'], str):
kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype'])
self.t = transform(*args, **kwargs)

try:
self.t = vision_transform_cls(*args, **kwargs)
except TypeError as e:
traceback.print_exc()
raise TypeError(
f'Error when init the {vision_transform_cls}, please '
f'check the argmemnts of {args} and {kwargs}. \n{e}')

def new_call(self, input):
try:
input['img'] = self.t(input['img'])
except Exception as e:
traceback.print_exc()
raise Exception('Error when processing of transform(`torhcvison/'
f'{vision_transform_cls.__name__}`). \n{e}')
return input

def new_str(self):
return str(self.t)
def __call__(self, results):
results['img'] = self.t(results['img'])
return results

new_transforms_cls = type(
new_name, (),
dict(__init__=new_init, __call__=new_call, __str__=new_str))
return new_transforms_cls
def __repr__(self) -> str:
return f'TorchVision{repr(self.t)}'


def register_vision_transforms() -> List[str]:
Expand All @@ -99,10 +78,11 @@ def register_vision_transforms() -> List[str]:
_transform = getattr(torchvision.transforms, module_name)
if inspect.isclass(_transform) and callable(
_transform) and not isinstance(_transform, (EnumMeta)):
new_cls = _warpper_vision_transform_cls(
_transform, f'TorchVison{module_name}')
from functools import partial
TRANSFORMS.register_module(
module=new_cls, name=f'torchvision/{module_name}')
module=partial(
TorchVisonTransformWrapper, transform=_transform),
name=f'torchvision/{module_name}')
vision_transforms.append(f'torchvision/{module_name}')
return vision_transforms

Expand Down
2 changes: 1 addition & 1 deletion requirements/mminstall.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mmcv>=2.0.0rc4,<2.1.0
mmengine>=0.4.0,<1.0.0
mmengine>=0.7.3,<1.0.0
2 changes: 1 addition & 1 deletion tests/test_datasets/test_transforms/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,4 +956,4 @@ def test_repr(self):
mmcls_trans = TRANSFORMS.build(
dict(type='torchvision/RandomResizedCrop', size=224))

self.assertEqual(str(vision_trans), str(mmcls_trans))
self.assertEqual(f'TorchVision{repr(vision_trans)}', repr(mmcls_trans))