This repository has been archived by the owner on May 9, 2024. It is now read-only.

fix backprop issue in mobilenetv3lite #59

Open · wants to merge 2 commits into main
Changes from all commits

212 changes: 142 additions & 70 deletions references/edgeailite/edgeai_xvision/xvision/models/mobilenetv3.py
@@ -63,19 +63,27 @@
#
#################################################################################

from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence

import torch

from functools import partial
from edgeai_torchmodelopt import xnn
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Callable, Dict, List, Optional, Sequence

from .utils import load_state_dict_from_url
from .mobilenetv2 import _make_divisible, ConvBNActivation

from edgeai_torchmodelopt import xnn
from .utils import load_state_dict_from_url

__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small", "MobileNetV3Lite", "mobilenet_v3_lite_large", "mobilenet_v3_lite_small"]
__all__ = [
"MobileNetV3",
"mobilenet_v3_large",
"mobilenet_v3_small",
"MobileNetV3Lite",
"mobilenet_v3_lite_large",
"mobilenet_v3_lite_small",
]


model_urls = {
@@ -89,8 +97,8 @@ def get_config():
model_config = xnn.utils.ConfigNode()
model_config.input_channels = 3
model_config.num_classes = 1000
model_config.width_mult = 1.
model_config.strides = None #(2,2,2,2,2)
model_config.width_mult = 1.0
model_config.strides = None # (2,2,2,2,2)
model_config.enable_fp16 = False
return model_config
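Aside: this config node is merged with whatever callers pass as model_config, so any default above can be overridden per model. A minimal sketch, assuming the factory functions below forward keyword arguments unchanged (they check for a model_config key, as shown further down):

```python
# Hypothetical usage: override defaults, then hand the node to a factory.
cfg = get_config()
cfg.num_classes = 100
cfg.width_mult = 0.75
cfg.strides = (2, 2, 2, 2, 1)  # stride 1 in the last stage keeps more resolution
model = mobilenet_v3_lite_large(pretrained=False, model_config=cfg)
```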

@@ -118,8 +126,18 @@ def forward(self, input: Tensor) -> Tensor:

class InvertedResidualConfig:
# Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
activation: str, stride: int, dilation: int, width_mult: float):
def __init__(
self,
input_channels: int,
kernel: int,
expanded_channels: int,
out_channels: int,
use_se: bool,
activation: str,
stride: int,
dilation: int,
width_mult: float,
):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
@@ -136,11 +154,15 @@ def adjust_channels(channels: int, width_mult: float):

class InvertedResidual(nn.Module):
# Implemented as described at section 5 of MobileNetV3 paper
def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation):
def __init__(
self,
cnf: InvertedResidualConfig,
norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation,
):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value')
raise ValueError("illegal stride value")

self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

@@ -149,20 +171,43 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod

# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(
ConvBNActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)

# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(
ConvBNActivation(
cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
stride=stride,
dilation=cnf.dilation,
groups=cnf.expanded_channels,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
if cnf.use_se:
layers.append(se_layer(cnf.expanded_channels))

# project
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.Identity))
layers.append(
ConvBNActivation(
cnf.expanded_channels,
cnf.out_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.Identity,
)
)

self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
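Aside: each InvertedResidual chains expand (1x1), depthwise (kxk, optional SE), then project (1x1), with a residual connection only when stride == 1 and the input and output channel counts match. A construction sketch, assuming "RE" selects ReLU as in upstream torchvision and reusing torchvision's BatchNorm eps/momentum defaults:

```python
# Build one bneck block the way _mobilenet_v3_conf does (illustrative values).
from functools import partial

import torch
from torch import nn

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
cnf = InvertedResidualConfig(
    input_channels=16, kernel=3, expanded_channels=64, out_channels=24,
    use_se=False, activation="RE", stride=2, dilation=1, width_mult=1.0,
)
block = InvertedResidual(cnf, norm_layer=norm_layer)
print(block(torch.randn(1, 16, 56, 56)).shape)  # torch.Size([1, 24, 28, 28])
```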
@@ -176,16 +221,15 @@ def forward(self, input: Tensor) -> Tensor:


class MobileNetV3(nn.Module):

def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
#num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = nn.Hardswish,
**kwargs: Dict
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
# num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = nn.Hardswish,
**kwargs: Dict,
) -> None:
"""
MobileNet V3 main class
@@ -196,20 +240,23 @@ def __init__(
num_classes (int): Number of classes
block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
activation_layer (Optional[Callable[..., nn.Module]]): Module specifying the activation layer to use
"""
model_config = get_config()
if 'model_config' in list(kwargs.keys()):
model_config = model_config.merge_from(kwargs['model_config'])
if "model_config" in list(kwargs.keys()):
model_config = model_config.merge_from(kwargs["model_config"])
#
strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2)
strides = model_config.strides if (model_config.strides is not None) else (2, 2, 2, 2, 2)
super().__init__()
self.num_classes = model_config.num_classes
self.enable_fp16 = model_config.enable_fp16

if not inverted_residual_setting:
raise ValueError("The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and
all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
elif not (
isinstance(inverted_residual_setting, Sequence)
and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
):
raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

if block is None:
@@ -222,8 +269,16 @@ def __init__(

# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvBNActivation(model_config.input_channels, firstconv_output_channels, kernel_size=3, stride=strides[0], norm_layer=norm_layer,
activation_layer=activation_layer))
layers.append(
ConvBNActivation(
model_config.input_channels,
firstconv_output_channels,
kernel_size=3,
stride=strides[0],
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)

# building inverted residual blocks
for cnf in inverted_residual_setting:
@@ -232,21 +287,28 @@ def __init__(
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(
ConvBNActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)

self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Sequential(
nn.Linear(lastconv_output_channels, last_channel),
activation_layer(inplace=False),
nn.Dropout(p=0.2, inplace=True),
nn.Dropout(p=0.2, inplace=False),
nn.Linear(last_channel, model_config.num_classes),
)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
@@ -271,18 +333,18 @@ def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
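Note that, formatting aside, the behavioral change in the classifier hunk above is nn.Dropout(p=0.2, inplace=True) becoming inplace=False, which matches the PR title: an in-place op can clobber a tensor that autograd saved for the backward pass. A standalone repro of that failure mode (not this exact model; torch.exp is used because its backward reuses its saved output):

```python
import torch

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)   # autograd saves y, since d/dx exp(x) = exp(x) = y
y.add_(1.0)        # in-place update bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)     # "... modified by an inplace operation ..."

x = torch.randn(4, requires_grad=True)
y = torch.exp(x) + 1.0  # out-of-place: the saved tensor stays intact
y.sum().backward()      # succeeds
```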


def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **params: Dict[str, Any]):
def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str = "HS", **params: Dict[str, Any]):
# non-public config parameters
reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
dilation = 2 if params.pop('_dilated', False) else 1
width_mult = params.pop('_width_mult', 1.0)
reduce_divider = 2 if params.pop("_reduced_tail", False) else 1
dilation = 2 if params.pop("_dilated", False) else 1
width_mult = params.pop("_width_mult", 1.0)

model_config = get_config()
if 'model_config' in list(params.keys()):
model_config = model_config.merge_from(params['model_config'])
if "model_config" in list(params.keys()):
model_config = model_config.merge_from(params["model_config"])
width_mult = max(width_mult, model_config.width_mult)
#
strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2)
strides = model_config.strides if (model_config.strides is not None) else (2, 2, 2, 2, 2)

bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
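Aside on the width arithmetic: adjust_channels delegates to _make_divisible (imported from .mobilenetv2 above), which in the torchvision lineage rounds to the nearest multiple of 8 but refuses to shrink a layer by more than roughly 10%. A sketch of that helper, on the assumption this file keeps the upstream behavior:

```python
def _make_divisible(v: float, divisor: int = 8, min_value: int = None) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never shrink by more than ~10%
        new_v += divisor
    return new_v

assert _make_divisible(32 * 0.75) == 24  # 24 is already a multiple of 8
assert _make_divisible(24 * 0.75) == 24  # 18 -> 16 drops >10%, so bump to 24
```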
@@ -302,8 +364,12 @@ def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **para
bneck_conf(80, 3, 480, 112, use_se, hs_type, 1, 1),
bneck_conf(112, 3, 672, 112, use_se, hs_type, 1, 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, use_se, hs_type, strides[4], dilation), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation),
bneck_conf(
160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation
),
bneck_conf(
160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation
),
]
last_channel = adjust_channels(1280 // reduce_divider) # C5
elif arch in ("mobilenet_v3_small", "mobilenet_v3_lite_small"):
@@ -317,8 +383,12 @@ def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **para
bneck_conf(40, 5, 120, 48, use_se, hs_type, 1, 1),
bneck_conf(48, 5, 144, 48, use_se, hs_type, 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, use_se, hs_type, strides[4], dilation), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation),
bneck_conf(
96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation
),
bneck_conf(
96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation
),
]
last_channel = adjust_channels(1024 // reduce_divider) # C5
else:
@@ -333,7 +403,7 @@ def _mobilenet_v3_model(
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any
**kwargs: Any,
):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained is True:
@@ -346,8 +416,8 @@ def _mobilenet_v3_model(
model.load_state_dict(state_dict)
elif isinstance(pretrained, str):
state_dict = torch.load(pretrained)
state_dict = state_dict['model'] if 'model' in state_dict else state_dict
state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict
state_dict = state_dict["model"] if "model" in state_dict else state_dict
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
model.load_state_dict(state_dict)

return model
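Worth noting: the two unwrapping lines above make pretrained double as a filesystem path, accepting checkpoints that nest weights under a "model" or "state_dict" key. A usage sketch (the path is illustrative):

```python
# Accepted layouts: {"model": sd}, {"state_dict": sd}, or the raw state dict.
model = mobilenet_v3_lite_large(pretrained="./checkpoints/mnv3_lite_large.pth")
```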
@@ -384,21 +454,23 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
###################################################
class MobileNetV3Lite(MobileNetV3):
def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
**kwargs
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
**kwargs,
) -> None:
kwargs['num_classes'] = kwargs.get('num_classes', 1000)
super().__init__(inverted_residual_setting,
last_channel=last_channel,
block=block,
norm_layer=norm_layer,
activation_layer=activation_layer,
**kwargs)
kwargs["num_classes"] = kwargs.get("num_classes", 1000)
super().__init__(
inverted_residual_setting,
last_channel=last_channel,
block=block,
norm_layer=norm_layer,
activation_layer=activation_layer,
**kwargs,
)
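For context: the Lite subclass swaps the default activation from Hardswish to plain ReLU; presumably _mobilenet_v3_lite_conf below also disables squeeze-excitation (its body is collapsed here, so that part is an assumption). A quick smoke test:

```python
import torch

lite = mobilenet_v3_lite_small(pretrained=False)
logits = lite(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```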


def _mobilenet_v3_lite_conf(arch: str, **params: Dict[str, Any]):
Expand All @@ -411,7 +483,7 @@ def _mobilenet_v3_lite_model(
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any
**kwargs: Any,
):
model = MobileNetV3Lite(inverted_residual_setting, last_channel, **kwargs)
if pretrained is True:
@@ -424,8 +496,8 @@ def _mobilenet_v3_lite_model(
model.load_state_dict(state_dict)
elif isinstance(pretrained, str):
state_dict = torch.load(pretrained)
state_dict = state_dict['model'] if 'model' in state_dict else state_dict
state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict
state_dict = state_dict["model"] if "model" in state_dict else state_dict
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
model.load_state_dict(state_dict)

return model