Skip to content

Commit

Permalink
Placeholder for new mnv3 model
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Aug 23, 2024
1 parent ed7aaf8 commit 76b0e99
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions timm/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ def _gen_mobilenet_v3_rw(


def _gen_mobilenet_v3(
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs
variant: str, channel_multiplier: float = 1.0, depth_multiplier: float = 1.0,
group_size=None, pretrained: bool = False, **kwargs
) -> MobileNetV3:
"""Creates a MobileNet-V3 model.
Expand Down Expand Up @@ -537,7 +538,7 @@ def _gen_mobilenet_v3(
]
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, group_size=group_size),
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, group_size=group_size),
num_features=num_features,
stem_size=16,
fix_stem=channel_multiplier < 0.75,
Expand Down Expand Up @@ -927,6 +928,9 @@ def _cfg(url: str = '', **kwargs):
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
paper_ids='arXiv:2104.10972v4',
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
'mobilenetv3_large_150d.untrained': _cfg(
#hf_hub_id='timm/',
),

'mobilenetv3_small_050.lamb_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
Expand Down Expand Up @@ -1099,6 +1103,11 @@ def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model

@register_model
def mobilenetv3_large_150d(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_large_150d', 1.5, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model

@register_model
def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
Expand Down

0 comments on commit 76b0e99

Please sign in to comment.