From 8ca448e6cfa91860affea917c1259fac296127e0 Mon Sep 17 00:00:00 2001
From: Adel Hasic
Date: Mon, 18 Mar 2024 11:38:46 +0100
Subject: [PATCH] fix backprop issue in mobilenetv3lite

MobileNetV3Lite failed with a RuntimeError during the first backward
pass. The error came from the final classifier nn.Sequential: the
in-place Dropout overwrote the output of the preceding activation
layer, so the tensor that autograd had saved for the backward pass no
longer matched its recorded version. Setting `inplace=False` on the
Dropout layer leaves the activation output untouched and lets backprop
run correctly. A standalone repro is included as a reviewer note after
the diff.

The remaining changes are formatting-only and were produced by the
formatter.
---
 .../xvision/models/mobilenetv3.py | 212 ++++++++++++------
 1 file changed, 142 insertions(+), 70 deletions(-)

diff --git a/references/edgeailite/edgeai_xvision/xvision/models/mobilenetv3.py b/references/edgeailite/edgeai_xvision/xvision/models/mobilenetv3.py
index ec871960194..88f3451e749 100644
--- a/references/edgeailite/edgeai_xvision/xvision/models/mobilenetv3.py
+++ b/references/edgeailite/edgeai_xvision/xvision/models/mobilenetv3.py
@@ -63,19 +63,27 @@
 #
 #################################################################################

+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Sequence
+
 import torch
-from functools import partial
+from edgeai_torchmodelopt import xnn
 from torch import nn, Tensor
 from torch.nn import functional as F
-from typing import Any, Callable, Dict, List, Optional, Sequence

-from .utils import load_state_dict_from_url
 from .mobilenetv2 import _make_divisible, ConvBNActivation
-from edgeai_torchmodelopt import xnn
+from .utils import load_state_dict_from_url

-__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small", "MobileNetV3Lite", "mobilenet_v3_lite_large", "mobilenet_v3_lite_small"]
+__all__ = [
+    "MobileNetV3",
+    "mobilenet_v3_large",
+    "mobilenet_v3_small",
+    "MobileNetV3Lite",
+    "mobilenet_v3_lite_large",
+    "mobilenet_v3_lite_small",
+]


 model_urls = {
@@ -89,8 +97,8 @@ def get_config():
     model_config = xnn.utils.ConfigNode()
     model_config.input_channels = 3
     model_config.num_classes = 1000
-    model_config.width_mult = 1.
-    model_config.strides = None #(2,2,2,2,2)
+    model_config.width_mult = 1.0
+    model_config.strides = None  # (2,2,2,2,2)
     model_config.enable_fp16 = False
     return model_config

@@ -118,8 +126,18 @@ def forward(self, input: Tensor) -> Tensor:

 class InvertedResidualConfig:
     # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
-    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
-                 activation: str, stride: int, dilation: int, width_mult: float):
+    def __init__(
+        self,
+        input_channels: int,
+        kernel: int,
+        expanded_channels: int,
+        out_channels: int,
+        use_se: bool,
+        activation: str,
+        stride: int,
+        dilation: int,
+        width_mult: float,
+    ):
         self.input_channels = self.adjust_channels(input_channels, width_mult)
         self.kernel = kernel
         self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
@@ -136,11 +154,15 @@ def adjust_channels(channels: int, width_mult: float):

 class InvertedResidual(nn.Module):
     # Implemented as described at section 5 of MobileNetV3 paper
-    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
-                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+    def __init__(
+        self,
+        cnf: InvertedResidualConfig,
+        norm_layer: Callable[..., nn.Module],
+        se_layer: Callable[..., nn.Module] = SqueezeExcitation,
+    ):
         super().__init__()
         if not (1 <= cnf.stride <= 2):
-            raise ValueError('illegal stride value')
+            raise ValueError("illegal stride value")

         self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

@@ -149,20 +171,43 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod

         # expand
         if cnf.expanded_channels != cnf.input_channels:
-            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
-                                           norm_layer=norm_layer, activation_layer=activation_layer))
+            layers.append(
+                ConvBNActivation(
+                    cnf.input_channels,
+                    cnf.expanded_channels,
+                    kernel_size=1,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )

         # depthwise
         stride = 1 if cnf.dilation > 1 else cnf.stride
-        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
-                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
-                                       norm_layer=norm_layer, activation_layer=activation_layer))
+        layers.append(
+            ConvBNActivation(
+                cnf.expanded_channels,
+                cnf.expanded_channels,
+                kernel_size=cnf.kernel,
+                stride=stride,
+                dilation=cnf.dilation,
+                groups=cnf.expanded_channels,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+            )
+        )
         if cnf.use_se:
             layers.append(se_layer(cnf.expanded_channels))

         # project
-        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=nn.Identity))
+        layers.append(
+            ConvBNActivation(
+                cnf.expanded_channels,
+                cnf.out_channels,
+                kernel_size=1,
+                norm_layer=norm_layer,
+                activation_layer=nn.Identity,
+            )
+        )

         self.block = nn.Sequential(*layers)
         self.out_channels = cnf.out_channels
@@ -176,16 +221,15 @@ def forward(self, input: Tensor) -> Tensor:


 class MobileNetV3(nn.Module):
-
     def __init__(
-            self,
-            inverted_residual_setting: List[InvertedResidualConfig],
-            last_channel: int,
-            #num_classes: int = 1000,
-            block: Optional[Callable[..., nn.Module]] = None,
-            norm_layer: Optional[Callable[..., nn.Module]] = None,
-            activation_layer: Optional[Callable[..., nn.Module]] = nn.Hardswish,
-            **kwargs: Dict
+        self,
+        inverted_residual_setting: List[InvertedResidualConfig],
+        last_channel: int,
+        # num_classes: int = 1000,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        activation_layer: Optional[Callable[..., nn.Module]] = nn.Hardswish,
+        **kwargs: Dict,
     ) -> None:
         """
         MobileNet V3 main class
@@ -196,20 +240,23 @@ def __init__(
             num_classes (int): Number of classes
             block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
             norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
+            activation_layer (Optional[Callable[..., nn.Module]]): Module specifying the activation layer to use
         """
         model_config = get_config()
-        if 'model_config' in list(kwargs.keys()):
-            model_config = model_config.merge_from(kwargs['model_config'])
+        if "model_config" in list(kwargs.keys()):
+            model_config = model_config.merge_from(kwargs["model_config"])
         #
-        strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2)
+        strides = model_config.strides if (model_config.strides is not None) else (2, 2, 2, 2, 2)
         super().__init__()
         self.num_classes = model_config.num_classes
         self.enable_fp16 = model_config.enable_fp16

         if not inverted_residual_setting:
             raise ValueError("The inverted_residual_setting should not be empty")
-        elif not (isinstance(inverted_residual_setting, Sequence) and
-                  all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
+        elif not (
+            isinstance(inverted_residual_setting, Sequence)
+            and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
+        ):
             raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

         if block is None:
@@ -222,8 +269,16 @@ def __init__(

         # building first layer
         firstconv_output_channels = inverted_residual_setting[0].input_channels
-        layers.append(ConvBNActivation(model_config.input_channels, firstconv_output_channels, kernel_size=3, stride=strides[0], norm_layer=norm_layer,
-                                       activation_layer=activation_layer))
+        layers.append(
+            ConvBNActivation(
+                model_config.input_channels,
+                firstconv_output_channels,
+                kernel_size=3,
+                stride=strides[0],
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+            )
+        )

         # building inverted residual blocks
         for cnf in inverted_residual_setting:
@@ -232,21 +287,28 @@ def __init__(
         # building last several layers
         lastconv_input_channels = inverted_residual_setting[-1].out_channels
         lastconv_output_channels = 6 * lastconv_input_channels
-        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
-                                       norm_layer=norm_layer, activation_layer=activation_layer))
+        layers.append(
+            ConvBNActivation(
+                lastconv_input_channels,
+                lastconv_output_channels,
+                kernel_size=1,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+            )
+        )

         self.features = nn.Sequential(*layers)
         self.avgpool = nn.AdaptiveAvgPool2d(1)
         self.classifier = nn.Sequential(
             nn.Linear(lastconv_output_channels, last_channel),
             activation_layer(inplace=False),
-            nn.Dropout(p=0.2, inplace=True),
+            nn.Dropout(p=0.2, inplace=False),
             nn.Linear(last_channel, model_config.num_classes),
         )

         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                 if m.bias is not None:
                     nn.init.zeros_(m.bias)
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
@@ -271,18 +333,18 @@ def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)


-def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **params: Dict[str, Any]):
+def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str = "HS", **params: Dict[str, Any]):
     # non-public config parameters
-    reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
-    dilation = 2 if params.pop('_dilated', False) else 1
-    width_mult = params.pop('_width_mult', 1.0)
+    reduce_divider = 2 if params.pop("_reduced_tail", False) else 1
+    dilation = 2 if params.pop("_dilated", False) else 1
+    width_mult = params.pop("_width_mult", 1.0)

     model_config = get_config()
-    if 'model_config' in list(params.keys()):
-        model_config = model_config.merge_from(params['model_config'])
+    if "model_config" in list(params.keys()):
+        model_config = model_config.merge_from(params["model_config"])
     width_mult = max(width_mult, model_config.width_mult)
     #
-    strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2)
+    strides = model_config.strides if (model_config.strides is not None) else (2, 2, 2, 2, 2)

     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
@@ -302,8 +364,12 @@ def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **para
             bneck_conf(80, 3, 480, 112, use_se, hs_type, 1, 1),
             bneck_conf(112, 3, 672, 112, use_se, hs_type, 1, 1),
             bneck_conf(112, 5, 672, 160 // reduce_divider, use_se, hs_type, strides[4], dilation),  # C4
-            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation),
-            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation),
+            bneck_conf(
+                160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation
+            ),
+            bneck_conf(
+                160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation
+            ),
         ]
         last_channel = adjust_channels(1280 // reduce_divider)  # C5
     elif arch in ("mobilenet_v3_small", "mobilenet_v3_lite_small"):
@@ -317,8 +383,12 @@ def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **para
             bneck_conf(40, 5, 120, 48, use_se, hs_type, 1, 1),
             bneck_conf(48, 5, 144, 48, use_se, hs_type, 1, 1),
             bneck_conf(48, 5, 288, 96 // reduce_divider, use_se, hs_type, strides[4], dilation),  # C4
-            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation),
-            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation),
+            bneck_conf(
+                96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation
+            ),
+            bneck_conf(
+                96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation
+            ),
         ]
         last_channel = adjust_channels(1024 // reduce_divider)  # C5
     else:
@@ -333,7 +403,7 @@ def _mobilenet_v3_model(
     last_channel: int,
     pretrained: bool,
     progress: bool,
-    **kwargs: Any
+    **kwargs: Any,
 ):
     model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
     if pretrained is True:
@@ -346,8 +416,8 @@ def _mobilenet_v3_model(
         model.load_state_dict(state_dict)
     elif isinstance(pretrained, str):
         state_dict = torch.load(pretrained)
-        state_dict = state_dict['model'] if 'model' in state_dict else state_dict
-        state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict
+        state_dict = state_dict["model"] if "model" in state_dict else state_dict
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
         model.load_state_dict(state_dict)
     return model

@@ -384,21 +454,23 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
 ###################################################
 class MobileNetV3Lite(MobileNetV3):
     def __init__(
-            self,
-            inverted_residual_setting: List[InvertedResidualConfig],
-            last_channel: int,
-            block: Optional[Callable[..., nn.Module]] = None,
-            norm_layer: Optional[Callable[..., nn.Module]] = None,
-            activation_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
-            **kwargs
+        self,
+        inverted_residual_setting: List[InvertedResidualConfig],
+        last_channel: int,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        activation_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
+        **kwargs,
     ) -> None:
-        kwargs['num_classes'] = kwargs.get('num_classes', 1000)
-        super().__init__(inverted_residual_setting,
-                         last_channel=last_channel,
-                         block=block,
-                         norm_layer=norm_layer,
-                         activation_layer=activation_layer,
-                         **kwargs)
+        kwargs["num_classes"] = kwargs.get("num_classes", 1000)
+        super().__init__(
+            inverted_residual_setting,
+            last_channel=last_channel,
+            block=block,
+            norm_layer=norm_layer,
+            activation_layer=activation_layer,
+            **kwargs,
+        )


 def _mobilenet_v3_lite_conf(arch: str, **params: Dict[str, Any]):
@@ -411,7 +483,7 @@ def _mobilenet_v3_lite_model(
     last_channel: int,
     pretrained: bool,
     progress: bool,
-    **kwargs: Any
+    **kwargs: Any,
 ):
     model = MobileNetV3Lite(inverted_residual_setting, last_channel, **kwargs)
     if pretrained is True:
@@ -424,8 +496,8 @@ def _mobilenet_v3_lite_model(
         model.load_state_dict(state_dict)
     elif isinstance(pretrained, str):
         state_dict = torch.load(pretrained)
-        state_dict = state_dict['model'] if 'model' in state_dict else state_dict
-        state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict
+        state_dict = state_dict["model"] if "model" in state_dict else state_dict
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
         model.load_state_dict(state_dict)
     return model
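
Reviewer note (not part of the patch): the standalone sketch below is a minimal
reproduction of the failure this commit fixes, using the same layer order as the
classifier built in MobileNetV3.__init__ (Linear -> ReLU -> Dropout -> Linear,
i.e. the Lite variant with its ReLU activation). The feature sizes, batch size,
and the make_head helper are illustrative choices for the repro, not values or
code taken from this repository.

    import torch
    from torch import nn

    def make_head(dropout_inplace: bool) -> nn.Sequential:
        # Same ordering as the MobileNetV3 classifier head; the Lite variant
        # uses ReLU, whose saved output is what the in-place Dropout rewrites.
        return nn.Sequential(
            nn.Linear(960, 1280),
            nn.ReLU(inplace=False),
            nn.Dropout(p=0.2, inplace=dropout_inplace),
            nn.Linear(1280, 1000),
        )

    x = torch.randn(4, 960)

    # Before the fix: the in-place Dropout modifies the activation output that
    # autograd saved for backward, so backward() raises a RuntimeError about a
    # variable modified by an inplace operation.
    broken = make_head(dropout_inplace=True).train()
    try:
        broken(x).sum().backward()
    except RuntimeError as err:
        print("inplace=True fails:", err)

    # After the fix: Dropout allocates a new tensor, the saved activation
    # output keeps its original version, and backprop completes.
    fixed = make_head(dropout_inplace=False).train()
    fixed(x).sum().backward()
    print("inplace=False backprops fine")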