Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyang07 committed Dec 14, 2022
1 parent b704a88 commit dcb67f7
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 61 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/nas_backbones/ofa_mobilenetv3_supernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@
with_se_list=[False, False, True, False, True, True],
act_cfg_list=['ReLU', 'ReLU', 'ReLU', 'HSwish', 'HSwish', 'HSwish'],
conv_cfg=dict(type='OFAConv2d'),
norm_cfg=dict(type='mmrazor.DynamicBatchNorm2d', momentum=0.0),
norm_cfg=dict(type='mmrazor.DynamicBatchNorm2d', momentum=0.1),
fine_grained_mode=True,
with_attentive_shortcut=False)
14 changes: 2 additions & 12 deletions mmrazor/structures/subnet/fix_subnet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict

from mmengine import fileio
from mmengine.logging import print_log
from torch import nn

from mmrazor.utils import FixMutable, ValidFixMutable
Expand Down Expand Up @@ -99,15 +97,8 @@ def load_fix_module(module):
_dynamic_to_static(model)


def export_fix_subnet(model: nn.Module,
dump_derived_mutable: bool = False) -> FixMutable:
def export_fix_subnet(model: nn.Module) -> FixMutable:
"""Export subnet that can be loaded by :func:`load_fix_subnet`."""
if dump_derived_mutable:
print_log(
'Trying to dump information of all derived mutables, '
'this might harm readability of the exported configurations.',
level=logging.WARNING)

# Avoid circular import
from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
from mmrazor.models.mutables.base_mutable import BaseMutable
Expand All @@ -121,8 +112,7 @@ def module_dump_chosen(module, fix_subnet):
fix_subnet: Dict[str, DumpChosen] = dict()
for name, module in model.named_modules():
if isinstance(module, BaseMutable):
if isinstance(module, MutableChannelContainer) and \
not dump_derived_mutable:
if isinstance(module, MutableChannelContainer):
continue

elif isinstance(module, DerivedMutable):
Expand Down
55 changes: 31 additions & 24 deletions tests/data/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def __init__(
self.with_attentive_shortcut = True
self.in_channels = 24

self.first_out_channels_list = [16]
self.first_conv = ConvModule(
in_channels=3,
out_channels=24,
Expand All @@ -750,8 +751,6 @@ def __init__(
norm_cfg=self.norm_cfg,
act_cfg=dict(type='Swish'))

self.last_mutable = OneShotMutableChannel(num_channels=24, candidate_choices=[16, 24])

self.layers = []
for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
enumerate(zip(self.num_blocks_list, self.kernel_size_list,
Expand Down Expand Up @@ -833,60 +832,68 @@ def _make_single_layer(self, out_channels, num_blocks,
return dynamic_seq

def register_mutables(self):
"""Mutate the BigNAS-style MobileNetV3."""
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)

# mutate the first conv
self.first_mutable_channels = OneShotMutableChannel(
alias='backbone.first_channels',
num_channels=max(self.first_out_channels_list),
candidate_choices=self.first_out_channels_list)

mutate_conv_module(
self.first_conv, mutable_out_channels=self.last_mutable)
self.first_conv, mutable_out_channels=self.first_mutable_channels)

mid_mutable = self.first_mutable_channels
# mutate the built mobilenet layers
for i, layer in enumerate(self.layers[:-1]):
num_blocks = self.num_blocks_list[i]
kernel_sizes = self.kernel_size_list[i]
expand_ratios = self.expand_ratio_list[i]
out_channels = self.num_channels_list[i]

prefix = 'backbone.layers.' + str(i + 1) + '.'

mutable_out_channels = OneShotMutableChannel(
alias=prefix + 'out_channels',
candidate_choices=out_channels,
num_channels=max(out_channels))

mutable_kernel_size = OneShotMutableValue(
value_list=kernel_sizes, default_value=max(kernel_sizes))
alias=prefix + 'kernel_size', value_list=kernel_sizes)

mutable_expand_ratio = OneShotMutableValue(
value_list=expand_ratios, default_value=max(expand_ratios))
mutable_out_channels = OneShotMutableChannel(
num_channels=max(out_channels), candidate_choices=out_channels)
alias=prefix + 'expand_ratio', value_list=expand_ratios)

se_ratios = [i / 4 for i in expand_ratios]
mutable_se_channels = OneShotMutableValue(
value_list=se_ratios, default_value=max(se_ratios))
mutable_depth = OneShotMutableValue(
alias=prefix + 'depth', value_list=num_blocks)
layer.register_mutable_attr('depth', mutable_depth)

for k in range(max(self.num_blocks_list[i])):
mutate_mobilenet_layer(layer[k], self.last_mutable,
mutate_mobilenet_layer(layer[k], mid_mutable,
mutable_out_channels,
mutable_se_channels,
mutable_expand_ratio,
mutable_kernel_size)
self.last_mutable = mutable_out_channels

mutable_depth = OneShotMutableValue(
value_list=num_blocks, default_value=max(num_blocks))
layer.register_mutable_attr('depth', mutable_depth)
mid_mutable = mutable_out_channels

mutable_out_channels = OneShotMutableChannel(
self.last_mutable_channels = OneShotMutableChannel(
alias='backbone.last_channels',
num_channels=self.out_channels,
candidate_choices=self.last_out_channels_list)

last_mutable_expand_value = OneShotMutableValue(
value_list=self.last_expand_ratio_list,
default_value=max(self.last_expand_ratio_list))
derived_expand_channels = self.last_mutable * last_mutable_expand_value

derived_expand_channels = mid_mutable * last_mutable_expand_value
mutate_conv_module(
self.layers[-1].final_expand_layer,
mutable_in_channels=self.last_mutable,
mutable_in_channels=mid_mutable,
mutable_out_channels=derived_expand_channels)
mutate_conv_module(
self.layers[-1].feature_mix_layer,
mutable_in_channels=derived_expand_channels,
mutable_out_channels=mutable_out_channels)

self.last_mutable = mutable_out_channels
mutable_out_channels=self.last_mutable_channels)

def forward(self, x):
x = self.first_conv(x)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_models/test_algorithms/test_autoslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def test_autoslim_train_step(self) -> None:
assert losses['max_subnet.loss'] > 0
assert losses['min_subnet.loss'] > 0
assert losses['min_subnet.loss_kl'] + 1e-5 > 0
assert losses['random_subnet_0.loss'] > 0
assert losses['random_subnet_0.loss_kl'] + 1e-5 > 0
assert losses['random_subnet_1.loss'] > 0
assert losses['random_subnet_1.loss_kl'] + 1e-5 > 0
assert losses['random0_subnet.loss'] > 0
assert losses['random0_subnet.loss_kl'] + 1e-5 > 0
assert losses['random1_subnet.loss'] > 0
assert losses['random1_subnet.loss_kl'] + 1e-5 > 0

assert algo._optim_wrapper_count_status_reinitialized
assert optim_wrapper._inner_count == 4
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_algorithms/test_bignas.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
label_smooth_val=0.1,
loss_weight=1.0),
topk=(1, 5)),
connect_head=dict(connect_with_backbone='backbone.last_mutable'),
connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'),
)

ALGORITHM_CFG = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_attentive_mobilenet_mutable() -> None:
elif isinstance(module, DynamicSequential):
assert isinstance(module.mutable_depth, OneShotMutableValue)

assert backbone.last_mutable.num_channels == max(out_channels[-1])
assert backbone.last_mutable_channels.num_channels == max(out_channels[-1])


def test_attentive_mobilenet_train() -> None:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_models/test_mutators/test_value_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def test_models_with_multiple_value(self):
for each_mutables in module.source_mutables:
if isinstance(each_mutables, MutableValue):
mutable_value_space.append(each_mutables)
assert len(
value_mutator.search_groups) == len(mutable_value_space)
count = 0
for values in value_mutator.search_groups.values():
count += len(values)
assert count == len(mutable_value_space)

x = torch.rand([2, 3, 224, 224])
y = model(x)
Expand Down
21 changes: 5 additions & 16 deletions tests/test_models/test_subnet/test_fix_subnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,27 +116,16 @@ def test_export_fix_subnet(self):
def test_export_fix_subnet_with_derived_mutable(self) -> None:
model = MockModelWithDerivedMutable()
fix_subnet = export_fix_subnet(model)
self.assertDictEqual(
fix_subnet, {'source_mutable': model.source_mutable.dump_chosen()})

fix_subnet['source_mutable'] = dict(
fix_subnet['source_mutable']._asdict())
fix_subnet['source_mutable']['chosen'] = 4
load_fix_subnet(model, fix_subnet)
assert model.source_mutable.current_choice == 4
assert model.derived_mutable.current_choice == 8

model = MockModelWithDerivedMutable()
fix_subnet = export_fix_subnet(model, dump_derived_mutable=True)
self.assertDictEqual(
fix_subnet, {
'source_mutable': model.source_mutable.dump_chosen(),
'derived_mutable': model.derived_mutable.dump_chosen()
'derived_mutable': model.source_mutable.dump_chosen()
})

fix_subnet['source_mutable'] = dict(
fix_subnet['source_mutable']._asdict())
fix_subnet['source_mutable']['chosen'] = 2
fix_subnet['source_mutable']['chosen'] = 4
load_fix_subnet(model, fix_subnet)
assert model.source_mutable.current_choice == 2
assert model.derived_mutable.current_choice == 4

assert model.source_mutable.current_choice == 4
assert model.derived_mutable.current_choice == 8

0 comments on commit dcb67f7

Please sign in to comment.