Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement] Update estimator with api revision #277

Merged
merged 7 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
num_mutation=25,
num_crossover=25,
mutate_prob=0.1,
flops_range=(0., 465 * 1e6),
flops_range=(0., 465.),
score_key='accuracy/top1')
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
num_mutation=25,
num_crossover=25,
mutate_prob=0.1,
flops_range=(0., 330 * 1e6),
flops_range=(0., 330.),
score_key='accuracy/top1')
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
num_mutation=20,
num_crossover=20,
mutate_prob=0.1,
flops_range=None,
score_key='bbox_mAP')
flops_range=(0., 300.),
score_key='coco/bbox_mAP')
45 changes: 20 additions & 25 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import random
Expand All @@ -14,11 +13,11 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.task_modules.estimators import get_model_complexity_info
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .utils import crossover
from .utils import check_subnet_flops, crossover


@LOOPS.register_module()
Expand All @@ -42,10 +41,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
num_crossover (int): The number of candidates got by crossover.
Defaults to 25.
mutate_prob (float): The probability of mutation. Defaults to 0.1.
flops_range (tuple, optional): flops_range to be used for screening
candidates.
spec_modules (list): Used for specify modules need to counter.
Defaults to list().
flops_range (tuple, optional): It is used for screening candidates.
resource_estimator_cfg (dict): The config for building estimator, which
is be used to estimate the flops of sampled subnet. Defaults to
None, which means default config is used.
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
init_candidates (str, optional): The candidates file path, which is
Expand All @@ -65,8 +64,8 @@ def __init__(self,
num_mutation: int = 25,
num_crossover: int = 25,
mutate_prob: float = 0.1,
flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
spec_modules: List = [],
flops_range: Optional[Tuple[float, float]] = (0., 330.),
resource_estimator_cfg: Optional[dict] = None,
score_key: str = 'accuracy/top1',
init_candidates: Optional[str] = None) -> None:
super().__init__(runner, dataloader, max_epochs)
Expand All @@ -85,7 +84,6 @@ def __init__(self,
self.num_candidates = num_candidates
self.top_k = top_k
self.flops_range = flops_range
self.spec_modules = spec_modules
self.score_key = score_key
self.num_mutation = num_mutation
self.num_crossover = num_crossover
Expand All @@ -101,6 +99,10 @@ def __init__(self,
correct init candidates file'

self.top_k_candidates = Candidates()
if resource_estimator_cfg is None:
self.estimator = ResourceEstimator()
else:
self.estimator = ResourceEstimator(**resource_estimator_cfg)

if self.runner.distributed:
self.model = runner.model.module
Expand Down Expand Up @@ -299,17 +301,10 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
Returns:
bool: The result of checking.
"""
if self.flops_range is None:
return True

self.model.set_subnet(random_subnet)
fix_mutable = export_fix_subnet(self.model)
copied_model = copy.deepcopy(self.model)
load_fix_subnet(copied_model, fix_mutable)
flops, _ = get_model_complexity_info(
copied_model, spec_modules=self.spec_modules)

if self.flops_range[0] <= flops <= self.flops_range[1]:
return True
else:
return False
is_pass = check_subnet_flops(
model=self.model,
subnet=random_subnet,
estimator=self.estimator,
flops_range=self.flops_range)

return is_pass
41 changes: 19 additions & 22 deletions mmrazor/engine/runner/subnet_sampler_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os
import random
Expand All @@ -13,10 +12,11 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.task_modules.estimators import get_model_complexity_info
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
from mmrazor.structures import Candidates
from mmrazor.utils import SupportRandomSubnet
from .utils import check_subnet_flops


class BaseSamplerTrainLoop(IterBasedTrainLoop):
Expand Down Expand Up @@ -103,8 +103,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
flops_range (dict): Constraints to be used for screening candidates.
spec_modules (list): Used for specify modules need to counter.
Defaults to list().
resource_estimator_cfg (dict): The config for building estimator, which
is be used to estimate the flops of sampled subnet. Defaults to
None, which means default config is used.
num_candidates (int): The number of the candidates consist of samples
from supernet and itself. Defaults to 1000.
num_samples (int): The number of sample in each sampling subnet.
Expand Down Expand Up @@ -138,8 +139,8 @@ def __init__(self,
val_begin: int = 1,
val_interval: int = 1000,
score_key: str = 'accuracy/top1',
flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
spec_modules: List = [],
flops_range: Optional[Tuple[float, float]] = (0., 330),
resource_estimator_cfg: Optional[dict] = None,
num_candidates: int = 1000,
num_samples: int = 10,
top_k: int = 5,
Expand All @@ -163,7 +164,6 @@ def __init__(self,

self.score_key = score_key
self.flops_range = flops_range
self.spec_modules = spec_modules
self.num_candidates = num_candidates
self.num_samples = num_samples
self.top_k = top_k
Expand All @@ -177,6 +177,10 @@ def __init__(self,

self.candidates = Candidates()
self.top_k_candidates = Candidates()
if resource_estimator_cfg is None:
self.estimator = ResourceEstimator()
else:
self.estimator = ResourceEstimator(**resource_estimator_cfg)

def run(self) -> None:
"""Launch training."""
Expand Down Expand Up @@ -317,20 +321,13 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
Returns:
bool: The result of checking.
"""
if self.flops_range is None:
return True

self.model.set_subnet(random_subnet)
fix_mutable = export_fix_subnet(self.model)
copied_model = copy.deepcopy(self.model)
load_fix_subnet(copied_model, fix_mutable)
flops, _ = get_model_complexity_info(
copied_model, spec_modules=self.spec_modules)

if self.flops_range[0] <= flops <= self.flops_range[1]:
return True
else:
return False
is_pass = check_subnet_flops(
model=self.model,
subnet=random_subnet,
estimator=self.estimator,
flops_range=self.flops_range)

return is_pass

def _save_candidates(self) -> None:
"""Save the candidates to init the next searching."""
Expand Down
3 changes: 2 additions & 1 deletion mmrazor/engine/runner/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .check import check_subnet_flops
from .genetic import crossover

__all__ = ['crossover']
__all__ = ['crossover', 'check_subnet_flops']
43 changes: 43 additions & 0 deletions mmrazor/engine/runner/utils/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional, Tuple

import torch.nn as nn
from mmdet.models.detectors import BaseDetector
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to use PlaceHolder for downstream repos.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


from mmrazor.models import ResourceEstimator
from mmrazor.structures import export_fix_subnet, load_fix_subnet
from mmrazor.utils import SupportRandomSubnet


def check_subnet_flops(
model: nn.Module,
subnet: SupportRandomSubnet,
estimator: ResourceEstimator,
flops_range: Optional[Tuple[float, float]] = None) -> bool:
"""Check whether is beyond flops constraints.

Returns:
bool: The result of checking.
"""
if flops_range is None:
return True

assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture')
model.set_subnet(subnet)
fix_mutable = export_fix_subnet(model)
copied_model = copy.deepcopy(model)
load_fix_subnet(copied_model, fix_mutable)

model_to_check = model.architecture
if isinstance(model_to_check, BaseDetector):
results = estimator.estimate(model=model_to_check.backbone)
else:
results = estimator.estimate(model=model_to_check)

flops = results['flops']
flops_mix, flops_max = flops_range
if flops_mix <= flops <= flops_max: # type: ignore
return True
else:
return False
44 changes: 20 additions & 24 deletions mmrazor/models/task_modules/estimators/base_estimator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Tuple
from typing import Dict, Tuple, Union

import torch.nn

Expand All @@ -12,44 +12,40 @@ class BaseEstimator(metaclass=ABCMeta):
"""The base class of Estimator, used for estimating model infos.

Args:
default_shape (tuple): Input data's default shape, for calculating
input_shape (tuple): Input data's default shape, for calculating
resources consume. Defaults to (1, 3, 224, 224).
units (str): Resource units. Defaults to 'M'.
disabled_counters (list): List of disabled spec op counters.
Defaults to None.
units (dict): A dict including required units. Default to dict().
as_strings (bool): Output FLOPs and params counts in a string
form. Default to False.
measure_inference (bool): whether to measure infer speed or not.
Default to False.
"""

def __init__(self,
default_shape: Tuple = (1, 3, 224, 224),
units: str = 'M',
disabled_counters: List[str] = None,
as_strings: bool = False,
measure_inference: bool = False):
assert len(default_shape) in [3, 4, 5], \
f'Unsupported shape: {default_shape}'
self.default_shape = default_shape
input_shape: Tuple = (1, 3, 224, 224),
units: Dict = dict(),
as_strings: bool = False):
assert len(input_shape) in [
3, 4, 5
], ('The length of input_shape must be in [3, 4, 5]. '
f'Got `{len(input_shape)}`.')
self.input_shape = input_shape
self.units = units
self.disabled_counters = disabled_counters
self.as_strings = as_strings
self.measure_inference = measure_inference

@abstractmethod
def estimate(
self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
) -> Dict[str, float]:
def estimate(self,
model: torch.nn.Module,
flops_params_cfg: dict = None,
latency_cfg: dict = None) -> Dict[str, Union[float, str]]:
"""Estimate the resources(flops/params/latency) of the given model.

Args:
model: The measured model.
resource_args (Dict[str, float]): resources information.
NOTE: resource_args have the same items() as the init cfgs.
flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
Default to None.
latency_cfg (dict): Cfg for estimating latency. Default to None.

Returns:
Dict[str, float]): A dict that containing resource results(flops,
params and latency).
Dict[str, Union[float, str]]): A dict that contains the resource
results(FLOPs, params and latency).
"""
pass
10 changes: 3 additions & 7 deletions mmrazor/models/task_modules/estimators/counters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .flops_params_counter import (get_model_complexity_info,
params_units_convert)
from .latency_counter import repeat_measure_inference_speed
from .flops_params_counter import get_model_flops_params
from .latency_counter import get_model_latency
from .op_counters import * # noqa: F401,F403

__all__ = [
'get_model_complexity_info', 'params_units_convert',
'repeat_measure_inference_speed'
]
__all__ = ['get_model_flops_params', 'get_model_latency']
Loading