diff --git a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py index f670e5c0c..4f5edb316 100644 --- a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 465 * 1e6), + flops_range=(0., 465.), score_key='accuracy/top1') diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py index 6f8dc9366..f3f963e40 100644 --- a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 330 * 1e6), + flops_range=(0., 330.), score_key='accuracy/top1') diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py index 0bd3b71fe..d1dd1637a 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py @@ -13,5 +13,5 @@ num_mutation=20, num_crossover=20, mutate_prob=0.1, - flops_range=None, - score_key='bbox_mAP') + flops_range=(0., 300.), + score_key='coco/bbox_mAP') diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index a72704e40..a9a76b383 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import os import os.path as osp import random @@ -14,11 +13,11 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules.estimators import get_model_complexity_info +from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet +from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet -from .utils import crossover +from .utils import check_subnet_flops, crossover @LOOPS.register_module() @@ -42,10 +41,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): num_crossover (int): The number of candidates got by crossover. Defaults to 25. mutate_prob (float): The probability of mutation. Defaults to 0.1. - flops_range (tuple, optional): flops_range to be used for screening - candidates. - spec_modules (list): Used for specify modules need to counter. - Defaults to list(). + flops_range (tuple, optional): It is used for screening candidates. + resource_estimator_cfg (dict): The config for building the estimator, which + is used to estimate the FLOPs of the sampled subnet. Defaults to + None, which means the default config is used. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. 
init_candidates (str, optional): The candidates file path, which is @@ -65,8 +64,8 @@ def __init__(self, num_mutation: int = 25, num_crossover: int = 25, mutate_prob: float = 0.1, - flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6), - spec_modules: List = [], + flops_range: Optional[Tuple[float, float]] = (0., 330.), + resource_estimator_cfg: Optional[dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -85,7 +84,6 @@ def __init__(self, self.num_candidates = num_candidates self.top_k = top_k self.flops_range = flops_range - self.spec_modules = spec_modules self.score_key = score_key self.num_mutation = num_mutation self.num_crossover = num_crossover @@ -101,6 +99,10 @@ def __init__(self, correct init candidates file' self.top_k_candidates = Candidates() + if resource_estimator_cfg is None: + self.estimator = ResourceEstimator() + else: + self.estimator = ResourceEstimator(**resource_estimator_cfg) if self.runner.distributed: self.model = runner.model.module @@ -299,17 +301,10 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: Returns: bool: The result of checking. """ - if self.flops_range is None: - return True - - self.model.set_subnet(random_subnet) - fix_mutable = export_fix_subnet(self.model) - copied_model = copy.deepcopy(self.model) - load_fix_subnet(copied_model, fix_mutable) - flops, _ = get_model_complexity_info( - copied_model, spec_modules=self.spec_modules) - - if self.flops_range[0] <= flops <= self.flops_range[1]: - return True - else: - return False + is_pass = check_subnet_flops( + model=self.model, + subnet=random_subnet, + estimator=self.estimator, + flops_range=self.flops_range) + + return is_pass diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index c2b4d2176..1127aab21 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import math import os import random @@ -13,10 +12,11 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules.estimators import get_model_complexity_info +from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet +from mmrazor.structures import Candidates from mmrazor.utils import SupportRandomSubnet +from .utils import check_subnet_flops class BaseSamplerTrainLoop(IterBasedTrainLoop): @@ -103,8 +103,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. flops_range (dict): Constraints to be used for screening candidates. - spec_modules (list): Used for specify modules need to counter. - Defaults to list(). + resource_estimator_cfg (dict): The config for building the estimator, which + is used to estimate the FLOPs of the sampled subnet. Defaults to + None, which means the default config is used. num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. num_samples (int): The number of sample in each sampling subnet. 
@@ -138,8 +139,8 @@ def __init__(self, val_begin: int = 1, val_interval: int = 1000, score_key: str = 'accuracy/top1', - flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6), - spec_modules: List = [], + flops_range: Optional[Tuple[float, float]] = (0., 330.), + resource_estimator_cfg: Optional[dict] = None, num_candidates: int = 1000, num_samples: int = 10, top_k: int = 5, @@ -163,7 +164,6 @@ def __init__(self, self.score_key = score_key self.flops_range = flops_range - self.spec_modules = spec_modules self.num_candidates = num_candidates self.num_samples = num_samples self.top_k = top_k @@ -177,6 +177,10 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() + if resource_estimator_cfg is None: + self.estimator = ResourceEstimator() + else: + self.estimator = ResourceEstimator(**resource_estimator_cfg) def run(self) -> None: """Launch training.""" @@ -317,20 +321,13 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: Returns: bool: The result of checking. """ - if self.flops_range is None: - return True - - self.model.set_subnet(random_subnet) - fix_mutable = export_fix_subnet(self.model) - copied_model = copy.deepcopy(self.model) - load_fix_subnet(copied_model, fix_mutable) - flops, _ = get_model_complexity_info( - copied_model, spec_modules=self.spec_modules) - - if self.flops_range[0] <= flops <= self.flops_range[1]: - return True - else: - return False + is_pass = check_subnet_flops( + model=self.model, + subnet=random_subnet, + estimator=self.estimator, + flops_range=self.flops_range) + + return is_pass def _save_candidates(self) -> None: """Save the candidates to init the next searching.""" diff --git a/mmrazor/engine/runner/utils/__init__.py b/mmrazor/engine/runner/utils/__init__.py index 7aaf29539..ec2f2cb29 100644 --- a/mmrazor/engine/runner/utils/__init__.py +++ b/mmrazor/engine/runner/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .check import check_subnet_flops from .genetic import crossover -__all__ = ['crossover'] +__all__ = ['crossover', 'check_subnet_flops'] diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py new file mode 100644 index 000000000..e2fdcfcc6 --- /dev/null +++ b/mmrazor/engine/runner/utils/check.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional, Tuple + +import torch.nn as nn + +from mmrazor.models import ResourceEstimator +from mmrazor.structures import export_fix_subnet, load_fix_subnet +from mmrazor.utils import SupportRandomSubnet + +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + + +def check_subnet_flops( + model: nn.Module, + subnet: SupportRandomSubnet, + estimator: ResourceEstimator, + flops_range: Optional[Tuple[float, float]] = None) -> bool: + """Check whether the subnet FLOPs are within ``flops_range``. + + Returns: + bool: The result of checking. 
+ """ + if flops_range is None: + return True + + assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture') + model.set_subnet(subnet) + fix_mutable = export_fix_subnet(model) + copied_model = copy.deepcopy(model) + load_fix_subnet(copied_model, fix_mutable) + + model_to_check = model.architecture + if isinstance(model_to_check, BaseDetector): + results = estimator.estimate(model=model_to_check.backbone) + else: + results = estimator.estimate(model=model_to_check) + + flops = results['flops'] + flops_mix, flops_max = flops_range + if flops_mix <= flops <= flops_max: # type: ignore + return True + else: + return False diff --git a/mmrazor/models/task_modules/estimators/base_estimator.py b/mmrazor/models/task_modules/estimators/base_estimator.py index 22a82d105..1a6f69264 100644 --- a/mmrazor/models/task_modules/estimators/base_estimator.py +++ b/mmrazor/models/task_modules/estimators/base_estimator.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Dict, Tuple, Union import torch.nn @@ -12,44 +12,40 @@ class BaseEstimator(metaclass=ABCMeta): """The base class of Estimator, used for estimating model infos. Args: - default_shape (tuple): Input data's default shape, for calculating + input_shape (tuple): Input data's default shape, for calculating resources consume. Defaults to (1, 3, 224, 224). - units (str): Resource units. Defaults to 'M'. - disabled_counters (list): List of disabled spec op counters. - Defaults to None. + units (dict): A dict including required units. Default to dict(). as_strings (bool): Output FLOPs and params counts in a string form. Default to False. - measure_inference (bool): whether to measure infer speed or not. - Default to False. """ def __init__(self, - default_shape: Tuple = (1, 3, 224, 224), - units: str = 'M', - disabled_counters: List[str] = None, - as_strings: bool = False, - measure_inference: bool = False): - assert len(default_shape) in [3, 4, 5], \ - f'Unsupported shape: {default_shape}' - self.default_shape = default_shape + input_shape: Tuple = (1, 3, 224, 224), + units: Dict = dict(), + as_strings: bool = False): + assert len(input_shape) in [ + 3, 4, 5 + ], ('The length of input_shape must be in [3, 4, 5]. ' + f'Got `{len(input_shape)}`.') + self.input_shape = input_shape self.units = units - self.disabled_counters = disabled_counters self.as_strings = as_strings - self.measure_inference = measure_inference @abstractmethod - def estimate( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, float]: + def estimate(self, + model: torch.nn.Module, + flops_params_cfg: dict = None, + latency_cfg: dict = None) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. Args: model: The measured model. - resource_args (Dict[str, float]): resources information. - NOTE: resource_args have the same items() as the init cfgs. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. Returns: - Dict[str, float]): A dict that containing resource results(flops, - params and latency). + Dict[str, Union[float, str]]): A dict that contains the resource + results(FLOPs, params and latency). 
""" pass diff --git a/mmrazor/models/task_modules/estimators/counters/__init__.py b/mmrazor/models/task_modules/estimators/counters/__init__.py index 0a6adee48..721987ec1 100644 --- a/mmrazor/models/task_modules/estimators/counters/__init__.py +++ b/mmrazor/models/task_modules/estimators/counters/__init__.py @@ -1,10 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .flops_params_counter import (get_model_complexity_info, - params_units_convert) -from .latency_counter import repeat_measure_inference_speed +from .flops_params_counter import get_model_flops_params +from .latency_counter import get_model_latency from .op_counters import * # noqa: F401,F403 -__all__ = [ - 'get_model_complexity_info', 'params_units_convert', - 'repeat_measure_inference_speed' -] +__all__ = ['get_model_flops_params', 'get_model_latency'] diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index 31e998a2a..f31208248 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import sys from functools import partial +from typing import Dict import torch import torch.nn as nn @@ -8,19 +9,21 @@ from mmrazor.registry import TASK_UTILS -def get_model_complexity_info(model, - input_shape=(1, 3, 224, 224), - spec_modules=[], - disabled_counters=[], - print_per_layer_stat=False, - as_strings=False, - input_constructor=None, - flush=False, - ost=sys.stdout): - """Get complexity information of a model. This method can calculate FLOPs - and parameter counts of a model with corresponding input shape. It can also - print complexity information for each layer in a model. Supported layers - are listed as below: +def get_model_flops_params(model, + input_shape=(1, 3, 224, 224), + spec_modules=[], + disabled_counters=[], + print_per_layer_stat=False, + units=dict(flops='M', params='M'), + as_strings=False, + seperate_return: bool = False, + input_constructor=None, + flush=False, + ost=sys.stdout): + """Get FLOPs and parameters of a model. This method can calculate FLOPs and + parameter counts of a model with corresponding input shape. It can also + print FLOPs and params for each layer in a model. Supported layers are + listed as below: - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``. - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``, @@ -39,16 +42,20 @@ def get_model_complexity_info(model, Args: model (nn.Module): The model for complexity calculation. input_shape (tuple): Input shape (including batchsize) used for - calculation. Default to (1, 3, 224, 224) + calculation. Default to (1, 3, 224, 224). spec_modules (list): A list that contains the names of several spec modules, which users want to get resources infos of them. e.g., ['backbone', 'head'], ['backbone.layer1']. Default to []. disabled_counters (list): One can limit which ops' spec would be calculated. Default to []. - print_per_layer_stat (bool): Whether to print complexity information + print_per_layer_stat (bool): Whether to print FLOPs and params for each layer in a model. Default to True. + units (dict): A dict including converted FLOPs and params units. + Default to dict(flops='M', params='M'). as_strings (bool): Output FLOPs and params counts in a string form. Default to True. 
+ seperate_return (bool): Whether to return the resource information + separately. Default to False. input_constructor (None | callable): If specified, it takes a callable method that generates input. otherwise, it will generate a random tensor with input shape to calculate FLOPs. Default to None. @@ -60,12 +67,16 @@ def get_model_complexity_info(model, tuple[float | str] | dict[str, float]: If `as_strings` is set to True, it will return FLOPs and parameter counts in a string format. Otherwise, it will return those in a float number format. - If len(spec_modules) > 0, it will return a resource info dict with - FLOPs and parameter counts of each spec module in float format. + NOTE: If seperate_return, it will return a resource info dict with + FLOPs & params counts of each spec module in float|string format. """ assert type(input_shape) is tuple assert len(input_shape) >= 1 assert isinstance(model, nn.Module) + if seperate_return and not len(spec_modules): + raise AssertionError('`seperate_return` can only be set to True when ' + '`spec_modules` are not empty.') + flops_params_model = add_flops_params_counting_methods(model) flops_params_model.eval() flops_params_model.start_flops_params_count(disabled_counters) @@ -96,34 +107,44 @@ def get_model_complexity_info(model, ost=ost, flush=flush) + if units is not None: + flops_count = params_units_convert(flops_count, units['flops']) + params_count = params_units_convert(params_count, units['params']) + + if as_strings: + flops_suffix = ' ' + units['flops'] + 'FLOPs' if units else ' FLOPs' + params_suffix = ' ' + units['params'] if units else '' + if len(spec_modules): + flops_count, params_count = 0.0, 0.0 module_names = [name for name, _ in flops_params_model.named_modules()] for module in spec_modules: assert module in module_names, \ f'All modules in spec_modules should be in the measured ' \ - f'flops_params_model. Got module {module} in spec_modules.' - spec_modules_resources = dict() - accumulate_sub_module_flops_params(flops_params_model) + f'flops_params_model. Got module `{module}` in spec_modules.' 
+ spec_modules_resources: Dict[str, dict] = dict() + accumulate_sub_module_flops_params(flops_params_model, units=units) for name, module in flops_params_model.named_modules(): if name in spec_modules: spec_modules_resources[name] = dict() spec_modules_resources[name]['flops'] = module.__flops__ spec_modules_resources[name]['params'] = module.__params__ + flops_count += module.__flops__ + params_count += module.__params__ if as_strings: - spec_modules_resources[name]['flops'] = str( - params_units_convert(module.__flops__, - 'G')) + ' GFLOPs' - spec_modules_resources[name]['params'] = str( - params_units_convert(module.__params__, 'M')) + ' M' + spec_modules_resources[name]['flops'] = \ + str(module.__flops__) + flops_suffix + spec_modules_resources[name]['params'] = \ + str(module.__params__) + params_suffix flops_params_model.stop_flops_params_count() - if len(spec_modules): + if seperate_return: return spec_modules_resources if as_strings: - flops_string = str(params_units_convert(flops_count, 'G')) + ' GFLOPs' - params_string = str(params_units_convert(params_count, 'M')) + ' M' + flops_string = str(flops_count) + flops_suffix + params_string = str(params_count) + params_suffix return flops_string, params_string return flops_count, params_count @@ -164,7 +185,7 @@ def params_units_convert(num_params, units='M', precision=3): def print_model_with_flops_params(model, total_flops, total_params, - units='G', + units=dict(flops='M', params='M'), precision=3, ost=sys.stdout, flush=False): @@ -174,7 +195,9 @@ def print_model_with_flops_params(model, model (nn.Module): The model to be printed. total_flops (float): Total FLOPs of the model. total_params (float): Total parameter counts of the model. - units (str | None): Converted FLOPs units. Default to 'G'. + units (tuple | none): A tuple pair including converted FLOPs & params + units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'. + Default to ('M', 'M'). precision (int): Digit number after the decimal point. Default to 3. ost (stream): same as `file` param in :func:`print`. Default to sys.stdout. @@ -200,8 +223,8 @@ def print_model_with_flops_params(model, >>> return x >>> model = ExampleModel() >>> x = (3, 16, 16) - to print the complexity information state for each layer, you can use - >>> get_model_complexity_info(model, x) + to print the FLOPs and params state for each layer, you can use + >>> get_model_flops_params(model, x) or directly use >>> print_model_with_flops_params(model, 4579784.0, 37361) ExampleModel( @@ -241,11 +264,11 @@ def flops_repr(self): accumulated_flops_cost = self.accumulate_flops() flops_string = str( params_units_convert( - accumulated_flops_cost, units=units, - precision=precision)) + ' ' + units + 'FLOPs' + accumulated_flops_cost, units['flops'], + precision=precision)) + ' ' + units['flops'] + 'FLOPs' params_string = str( - params_units_convert( - accumulated_num_params, units='M', precision=precision)) + ' M' + params_units_convert(accumulated_num_params, units['params'], + precision)) + ' M' return ', '.join([ params_string, '{:.3%} Params'.format(accumulated_num_params / total_params), @@ -277,12 +300,15 @@ def del_extra_repr(m): model.apply(del_extra_repr) -def accumulate_sub_module_flops_params(model): +def accumulate_sub_module_flops_params(model, units=None): """Accumulate FLOPs and params for each module in the model. Each module in the model will have the `__flops__` and `__params__` parameters. Args: model (nn.Module): The model to be accumulated. 
+ units (tuple | none): A tuple pair including converted FLOPs & params + units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'. + Default to None. """ def accumulate_params(module): @@ -310,6 +336,9 @@ def accumulate_flops(module): _params = accumulate_params(module) module.__flops__ = _flops module.__params__ = _params + if units is not None: + module.__flops__ = params_units_convert(_flops, units['flops']) + module.__params__ = params_units_convert(_params, units['params']) def get_model_parameters_number(model): diff --git a/mmrazor/models/task_modules/estimators/counters/latency_counter.py b/mmrazor/models/task_modules/estimators/counters/latency_counter.py index e3e91c54e..a4241e313 100644 --- a/mmrazor/models/task_modules/estimators/counters/latency_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/latency_counter.py @@ -1,71 +1,89 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging import time -from typing import Any, Dict +from typing import Tuple, Union import torch from mmengine.logging import print_log -def repeat_measure_inference_speed(model: torch.nn.Module, - resource_args: Dict[str, Any], - max_iter: int = 100, - num_warmup: int = 5, - log_interval: int = 100, - repeat_num: int = 1) -> float: +def get_model_latency(model: torch.nn.Module, + input_shape: Tuple = (1, 3, 224, 224), + unit: str = 'ms', + as_strings: bool = False, + max_iter: int = 100, + num_warmup: int = 5, + log_interval: int = 100, + repeat_num: int = 1) -> Union[float, str]: """Repeat speed measure for multi-times to get more precise results. Args: model (torch.nn.Module): The measured model. - resource_args (Dict[str, float]): resources information. - max_iter (Optional[int]): Max iteration num for inference speed test. + input_shape (tuple): Input shape (including batchsize) used for + calculation. Default to (1, 3, 224, 224). + unit (str): Unit of latency in string format. Default to 'ms'. + as_strings (bool): Output latency counts in a string form. + Default to False. + max_iter (Optional[int]): Max iteration num for the measurement. + Default to 100. num_warmup (Optional[int]): Iteration num for warm-up stage. + Default to 5. log_interval (Optional[int]): Interval num for logging the results. + Default to 100. repeat_num (Optional[int]): Num of times to repeat the measurement. + Default to 1. Returns: - fps (float): The measured inference speed of the model. + latency (Union[float, str]): The measured inference speed of the model. + if ``as_strings=True``, it will return latency in string format. 
""" assert repeat_num >= 1 fps_list = [] for _ in range(repeat_num): - fps_list.append( - measure_inference_speed(model, resource_args, max_iter, num_warmup, - log_interval)) + _get_model_latency(model, input_shape, max_iter, num_warmup, + log_interval)) + + latency = round(1000 / fps_list[0], 1) if repeat_num > 1: - fps_list_ = [round(fps, 1) for fps in fps_list] + _fps_list = [round(fps, 1) for fps in fps_list] times_per_img_list = [round(1000 / fps, 1) for fps in fps_list] - mean_fps_ = sum(fps_list_) / len(fps_list_) + _mean_fps = sum(_fps_list) / len(_fps_list) mean_times_per_img = sum(times_per_img_list) / len(times_per_img_list) print_log( - f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, ' + f'Overall fps: {_fps_list}[{_mean_fps:.1f}] img / s, ' f'times per image: ' f'{times_per_img_list}[{mean_times_per_img:.1f}] ms/img', logger='current', level=logging.DEBUG) - return mean_times_per_img + latency = mean_times_per_img + + if as_strings: + latency = str(latency) + ' ' + unit # type: ignore - latency = round(1000 / fps_list[0], 1) return latency -def measure_inference_speed(model: torch.nn.Module, - resource_args: Dict[str, Any], - max_iter: int = 100, - num_warmup: int = 5, - log_interval: int = 100) -> float: +def _get_model_latency(model: torch.nn.Module, + input_shape: Tuple = (1, 3, 224, 224), + max_iter: int = 100, + num_warmup: int = 5, + log_interval: int = 100) -> float: """Measure inference speed on GPU devices. Args: model (torch.nn.Module): The measured model. - resource_args (Dict[str, float]): resources information. - max_iter (Optional[int]): Max iteration num for inference speed test. + input_shape (tuple): Input shape (including batchsize) used for + calculation. Default to (1, 3, 224, 224). + max_iter (Optional[int]): Max iteration num for the measurement. + Default to 100. num_warmup (Optional[int]): Iteration num for warm-up stage. + Default to 5. log_interval (Optional[int]): Interval num for logging the results. + Default to 100. Returns: fps (float): The measured inference speed of the model. @@ -78,10 +96,11 @@ def measure_inference_speed(model: torch.nn.Module, device = 'cuda' else: raise NotImplementedError('To use cpu to test latency not supported.') + # benchmark with {max_iter} image and take the average for i in range(1, max_iter): if device == 'cuda': - data = torch.rand(resource_args['input_shape']).cuda() + data = torch.rand(input_shape).cuda() torch.cuda.synchronize() start_time = time.perf_counter() diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py index 6d4342866..ac5292d0c 100644 --- a/mmrazor/models/task_modules/estimators/resource_estimator.py +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -1,13 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List, Tuple +from typing import Dict, Optional, Tuple, Union import torch.nn -from mmengine.dist import broadcast_object_list, is_main_process from mmrazor.registry import TASK_UTILS from .base_estimator import BaseEstimator -from .counters import (get_model_complexity_info, params_units_convert, - repeat_measure_inference_speed) +from .counters import get_model_flops_params, get_model_latency @TASK_UTILS.register_module() @@ -15,24 +13,30 @@ class ResourceEstimator(BaseEstimator): """Estimator for calculating the resources consume. Args: - default_shape (tuple): Input data's default shape, for calculating - resources consume. 
Defaults to (1, 3, 224, 224) - units (str): Resource units. Defaults to 'M'. - disabled_counters (list): List of disabled spec op counters. - Defaults to None. - NOTE: disabled_counters contains the op counter class names - in estimator.op_counters that require to be disabled, - such as 'ConvCounter', 'BatchNorm2dCounter', ... + input_shape (tuple): Input data's default shape, for calculating + resources consume. Defaults to (1, 3, 224, 224). + units (dict): Dict that contains converted FLOPs/params/latency units. + Default to dict(flops='M', params='M', latency='ms'). + as_strings (bool): Output FLOPs/params/latency counts in a string + form. Default to False. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. Examples: >>> # direct calculate resource consume of nn.Conv2d >>> conv2d = nn.Conv2d(3, 32, 3) - >>> estimator = ResourceEstimator() - >>> estimator.estimate( - ... model=conv2d, - ... resource_args=dict(input_shape=(1, 3, 64, 64))) + >>> estimator = ResourceEstimator(input_shape=(1, 3, 64, 64)) + >>> estimator.estimate(model=conv2d) {'flops': 3.444, 'params': 0.001, 'latency': 0.0} + >>> # direct calculate resource consume of nn.Conv2d + >>> conv2d = nn.Conv2d(3, 32, 3) + >>> estimator = ResourceEstimator() + >>> flops_params_cfg = dict(input_shape=(1, 3, 32, 32)) + >>> estimator.estimate(model=conv2d, flops_params_cfg) + {'flops': 0.806, 'params': 0.001, 'latency': 0.0} + >>> # calculate resources of custom modules >>> class CustomModule(nn.Module): ... @@ -51,17 +55,14 @@ class ResourceEstimator(BaseEstimator): ... module.__params__ += 700000 ... >>> model = CustomModule() - >>> estimator.estimate( - ... model=model, - ... resource_args=dict(input_shape=(1, 3, 64, 64))) + >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64)) + >>> estimator.estimate(model=model, flops_params_cfg) {'flops': 1.0, 'params': 0.7, 'latency': 0.0} ... >>> # calculate resources of custom modules with disable_counters - >>> estimator.estimate( - ... model=model, - ... resource_args=dict( - ... input_shape=(1, 3, 64, 64), - ... disabled_counters=['CustomModuleCounter'])) + >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64), + ... disabled_counters=['CustomModuleCounter']) + >>> estimator.estimate(model=model, flops_params_cfg) {'flops': 0.0, 'params': 0.0, 'latency': 0.0} >>> # calculate resources of mmrazor.models @@ -69,87 +70,146 @@ class ResourceEstimator(BaseEstimator): mmrazor.engine.hooks.estimate_resources_hook for details. """ - def __init__(self, - default_shape: Tuple = (1, 3, 224, 224), - units: str = 'M', - disabled_counters: List[str] = [], - as_strings: bool = False, - measure_inference: bool = False): - super().__init__(default_shape, units, disabled_counters, as_strings, - measure_inference) - - def estimate( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, Any]: + def __init__( + self, + input_shape: Tuple = (1, 3, 224, 224), + units: Dict = dict(flops='M', params='M', latency='ms'), + as_strings: bool = False, + flops_params_cfg: Optional[dict] = None, + latency_cfg: Optional[dict] = None, + ): + super().__init__(input_shape, units, as_strings) + if not isinstance(units, dict): + raise TypeError('units for estimator should be a dict', + f'but got `{type(units)}`') + for unit_key in units: + if unit_key not in ['flops', 'params', 'latency']: + raise KeyError(f'Got invalid key `{unit_key}` in units. 
', + 'Should be `flops`, `params` or `latency`.') + if flops_params_cfg: + self.flops_params_cfg = flops_params_cfg + else: + self.flops_params_cfg = dict() + self.latency_cfg = latency_cfg if latency_cfg else dict() + + def estimate(self, + model: torch.nn.Module, + flops_params_cfg: dict = None, + latency_cfg: dict = None) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. + This method will first parse the merged :attr:`self.flops_params_cfg` + and the :attr:`self.latency_cfg` to check whether the keys are valid. + Args: model: The measured model. - resource_args (Dict[str, float]): Args for resources estimation. - NOTE: resource_args have the same items() as the init cfgs. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. + + NOTE: If the `flops_params_cfg` and `latency_cfg` are both None, + this method will only estimate FLOPs/params with default settings. Returns: - Dict[str, str]): A dict that containing resource results(flops, - params and latency). + Dict[str, Union[float, str]]): A dict that contains the resource + results(FLOPs, params and latency). """ resource_metrics = dict() - if is_main_process(): - measure_inference = resource_args.pop('measure_inference', False) - if 'input_shape' not in resource_args.keys(): - resource_args['input_shape'] = self.default_shape - if 'disabled_counters' not in resource_args.keys(): - resource_args['disabled_counters'] = self.disabled_counters - model.eval() - flops, params = get_model_complexity_info(model, **resource_args) - if measure_inference: - latency = repeat_measure_inference_speed( - model, resource_args, max_iter=100, repeat_num=2) - else: - latency = 0.0 - as_strings = resource_args.get('as_strings', self.as_strings) - if as_strings and self.units is not None: - raise ValueError('Set units to None, when as_trings=True.') - if self.units is not None: - flops = params_units_convert(flops, self.units) - params = params_units_convert(params, self.units) - resource_metrics.update({ - 'flops': flops, - 'params': params, - 'latency': latency - }) - results = [resource_metrics] + measure_latency = True if latency_cfg else False + + if flops_params_cfg: + flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg} + self._check_flops_params_cfg(flops_params_cfg) + flops_params_cfg = self._set_default_resource_params( + flops_params_cfg) else: - results = [None] # type: ignore + flops_params_cfg = self.flops_params_cfg - broadcast_object_list(results) + if latency_cfg: + latency_cfg = {**self.latency_cfg, **latency_cfg} + self._check_latency_cfg(latency_cfg) + latency_cfg = self._set_default_resource_params(latency_cfg) + else: + latency_cfg = self.latency_cfg + + model.eval() + flops, params = get_model_flops_params(model, **flops_params_cfg) + if measure_latency: + latency = get_model_latency(model, **latency_cfg) + else: + latency = '0.0 ms' if self.as_strings else 0.0 # type: ignore - return results[0] + resource_metrics.update({ + 'flops': flops, + 'params': params, + 'latency': latency + }) + return resource_metrics - def estimate_spec_modules( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, float]: - """Estimate the resources(flops/params/latency) of the spec modules. 
+ def estimate_separation_modules( + self, + model: torch.nn.Module, + flops_params_cfg: dict = None) -> Dict[str, Union[float, str]]: + """Estimate FLOPs and params of the spec modules with separate return. Args: model: The measured model. - resource_args (Dict[str, float]): Args for resources estimation. - NOTE: resource_args have the same items() as the init cfgs. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. Returns: - Dict[str, float]): A dict that containing resource results(flops, - params) of each modules in resource_args['spec_modules']. + Dict[str, Union[float, str]]): A dict that contains the FLOPs and + params results (string | float format) of each modules in the + ``flops_params_cfg['spec_modules']``. """ - assert 'spec_modules' in resource_args, \ - 'spec_modules is required when calling estimate_spec_modules().' + if flops_params_cfg: + flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg} + self._check_flops_params_cfg(flops_params_cfg) + flops_params_cfg = self._set_default_resource_params( + flops_params_cfg) + else: + flops_params_cfg = self.flops_params_cfg + flops_params_cfg['seperate_return'] = True - resource_args.pop('measure_inference', False) - if 'input_shape' not in resource_args.keys(): - resource_args['input_shape'] = self.default_shape - if 'disabled_counters' not in resource_args.keys(): - resource_args['disabled_counters'] = self.disabled_counters + assert len(flops_params_cfg['spec_modules']), ( + 'spec_modules can not be empty when calling ' + f'`estimate_separation_modules` of {self.__class__.__name__} ') model.eval() - spec_modules_resources = get_model_complexity_info( - model, **resource_args) - + spec_modules_resources = get_model_flops_params( + model, **flops_params_cfg) return spec_modules_resources + + def _check_flops_params_cfg(self, flops_params_cfg: dict) -> None: + """Check the legality of ``flops_params_cfg``. + + Args: + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + """ + for key in flops_params_cfg: + if key not in get_model_flops_params.__code__.co_varnames[ + 1:]: # type: ignore + raise KeyError(f'Got invalid key `{key}` in flops_params_cfg.') + + def _check_latency_cfg(self, latency_cfg: dict) -> None: + """Check the legality of ``latency_cfg``. + + Args: + latency_cfg (dict): Cfg for estimating latency. + """ + for key in latency_cfg: + if key not in get_model_latency.__code__.co_varnames[ + 1:]: # type: ignore + raise KeyError(f'Got invalid key `{key}` in latency_cfg.') + + def _set_default_resource_params(self, cfg: dict) -> dict: + """Set default attributes for the input cfgs. + + Args: + cfg (dict): flops_params_cfg or latency_cfg. + """ + default_common_settings = ['input_shape', 'units', 'as_strings'] + for key in default_common_settings: + if key not in cfg: + cfg[key] = getattr(self, key) + return cfg diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py index 85f25eaff..5365831b9 100644 --- a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py +++ b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py @@ -1,9 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import torch -from mmdet.models import BaseDetector from mmrazor.registry import TASK_UTILS +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + # todo: adapt to mmdet 2.0 @TASK_UTILS.register_module() diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py index 99be89bca..60bcef4ba 100644 --- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -118,9 +118,9 @@ def sample_choice(self, model: Module) -> None: def test_estimate(self) -> None: fool_conv2d = FoolConv2d() + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) results = estimator.estimate( - model=fool_conv2d, - resource_args=dict(input_shape=(1, 3, 224, 224))) + model=fool_conv2d, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] @@ -129,9 +129,9 @@ def test_estimate(self) -> None: def test_register_module(self) -> None: fool_add_constant = FoolConvModule() + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) results = estimator.estimate( - model=fool_add_constant, - resource_args=dict(input_shape=(1, 3, 224, 224))) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] @@ -140,46 +140,65 @@ def test_register_module(self) -> None: def test_disable_sepc_counter(self) -> None: fool_add_constant = FoolConvModule() + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), + disabled_counters=['FoolAddConstantCounter']) rest_results = estimator.estimate( - model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), - disabled_counters=['FoolAddConstantCounter'])) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) rest_flops_count = rest_results['flops'] rest_params_count = rest_results['params'] self.assertLess(rest_flops_count, 45.158) self.assertLess(rest_params_count, 0.701) - def test_estimate_spec_modules(self) -> None: + def test_estimate_spec_module(self) -> None: fool_add_constant = FoolConvModule() - results = estimator.estimate_spec_modules( - model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), spec_modules=['add_constant'])) + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), + spec_modules=['add_constant', 'conv2d']) + results = estimator.estimate( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) + flops_count = results['flops'] + params_count = results['params'] + + self.assertEqual(flops_count, 45.158) + self.assertEqual(params_count, 0.701) + + def test_estimate_separation_modules(self) -> None: + fool_add_constant = FoolConvModule() + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=['add_constant']) + results = estimator.estimate_separation_modules( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) self.assertGreater(results['add_constant']['flops'], 0) with pytest.raises(AssertionError): - results = estimator.estimate_spec_modules( - model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), spec_modules=['backbone'])) + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=['backbone']) + results = estimator.estimate_separation_modules( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) + + with 
pytest.raises(AssertionError): + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=[]) + results = estimator.estimate_separation_modules( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) def test_estimate_subnet(self) -> None: - resource_args = dict(input_shape=(1, 3, 224, 224)) + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) model = MODELS.build(BACKBONE_CFG) self.sample_choice(model) copied_model = copy.deepcopy(model) results = estimator.estimate( - model=copied_model, resource_args=resource_args) + model=copied_model, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] fix_subnet = export_fix_subnet(model) load_fix_subnet(copied_model, fix_subnet) subnet_results = estimator.estimate( - model=copied_model, resource_args=resource_args) + model=copied_model, flops_params_cfg=flops_params_cfg) subnet_flops_count = subnet_results['flops'] subnet_params_count = subnet_results['params'] @@ -188,8 +207,8 @@ def test_estimate_subnet(self) -> None: # test whether subnet estimate will affect original model copied_model = copy.deepcopy(model) - results_after_estimate = \ - estimator.estimate(model=copied_model, resource_args=resource_args) + results_after_estimate = estimator.estimate( + model=copied_model, flops_params_cfg=flops_params_cfg) flops_count_after_estimate = results_after_estimate['flops'] params_count_after_estimate = results_after_estimate['params'] diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py index 14e642c57..f30019274 100644 --- a/tests/test_runners/test_evolution_search_loop.py +++ b/tests/test_runners/test_evolution_search_loop.py @@ -112,10 +112,7 @@ def test_init(self): self.assertEqual(loop.candidates, fake_candidates) @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet') - @patch( - 'mmrazor.engine.runner.evolution_search_loop.get_model_complexity_info' - ) - def test_run_epoch(self, mock_flops, mock_export_fix_subnet): + def test_run_epoch(self, mock_export_fix_subnet): # test_run_epoch: distributed == False loop_cfg = copy.deepcopy(self.train_cfg) loop_cfg.runner = self.runner @@ -155,7 +152,7 @@ def test_run_epoch(self, mock_flops, mock_export_fix_subnet): self.runner.work_dir = self.temp_dir fake_subnet = {'1': 'choice1', '2': 'choice2'} loop.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = (50., 1) + loop._check_constraints = MagicMock(return_value=True) mock_export_fix_subnet.return_value = fake_subnet loop.run_epoch() self.assertEqual(len(loop.candidates), 4) diff --git a/tests/test_runners/test_subnet_sampler_loop.py b/tests/test_runners/test_subnet_sampler_loop.py index 0f26c5aeb..fca29b823 100644 --- a/tests/test_runners/test_subnet_sampler_loop.py +++ b/tests/test_runners/test_subnet_sampler_loop.py @@ -192,30 +192,15 @@ def test_sample_subnet(self): self.assertEqual(subnet, fake_subnet) self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1) - @patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet') - @patch( - 'mmrazor.engine.runner.subnet_sampler_loop.get_model_complexity_info') - def test_run(self, mock_flops, mock_export_fix_subnet): - # test run with flops_range=None - cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_run1' - runner = Runner.from_cfg(cfg) - fake_subnet = {'1': 'choice1', '2': 'choice2'} - runner.model.sample_subnet = MagicMock(return_value=fake_subnet) - runner.train() - - 
self.assertEqual(runner.iter, runner.max_iters) - assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl')) - + def test_run(self): # test run with _check_constraints cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_run2' - cfg.train_cfg.flops_range = (0, 100) + cfg.experiment_name = 'test_run1' runner = Runner.from_cfg(cfg) fake_subnet = {'1': 'choice1', '2': 'choice2'} runner.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = (50., 1) - mock_export_fix_subnet.return_value = fake_subnet + loop = runner.build_train_loop(cfg.train_cfg) + loop._check_constraints = MagicMock(return_value=True) runner.train() self.assertEqual(runner.iter, runner.max_iters) diff --git a/tests/test_runners/test_utils/test_check.py b/tests/test_runners/test_utils/test_check.py new file mode 100644 index 000000000..b9bd57989 --- /dev/null +++ b/tests/test_runners/test_utils/test_check.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +from mmrazor.engine.runner.utils import check_subnet_flops + +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + + +@patch('mmrazor.models.ResourceEstimator') +@patch('mmrazor.models.SPOS') +def test_check_subnet_flops(mock_model, mock_estimator): + # flops_range = None + flops_range = None + fake_subnet = {'1': 'choice1', '2': 'choice2'} + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) + assert result is True + + # flops_range is not None + # architecture is BaseDetector + flops_range = (0., 100.) + mock_model.architecture = BaseDetector + fake_results = {'flops': 50.} + mock_estimator.estimate.return_value = fake_results + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) + assert result is True + + # flops_range is not None + # flops out of flops_range + flops_range = (0., 100.) + fake_results = {'flops': -50.} + mock_estimator.estimate.return_value = fake_results + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) + assert result is False
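For reference, a minimal usage sketch of the refactored ResourceEstimator API, assembled from the docstring examples added in this patch (illustrative, not part of the diff). The default units are dict(flops='M', params='M', latency='ms'), which is why the search configs above can express flops_range directly in MFLOPs, e.g. (0., 465.) instead of 465 * 1e6.

# Sketch only; expected values are taken from the docstrings above.
import torch.nn as nn

from mmrazor.models.task_modules import ResourceEstimator

estimator = ResourceEstimator(input_shape=(1, 3, 64, 64))
conv2d = nn.Conv2d(3, 32, 3)

# Results come back already converted to the configured units.
results = estimator.estimate(model=conv2d)
print(results)  # {'flops': 3.444, 'params': 0.001, 'latency': 0.0} per the docstring

# Per-call overrides go through `flops_params_cfg` / `latency_cfg`
# instead of the removed `resource_args` dict.
results = estimator.estimate(
    model=conv2d, flops_params_cfg=dict(input_shape=(1, 3, 32, 32)))
print(results)  # {'flops': 0.806, 'params': 0.001, 'latency': 0.0} per the docstring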
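And a hypothetical search-config fragment showing how the new arguments fit together; only num_mutation, num_crossover, mutate_prob, flops_range and score_key appear in the hunks above, while the registry name 'mmrazor.EvolutionSearchLoop' and the remaining layout are assumptions for illustration.

# Hypothetical fragment; dataloader/evaluator and other required keys are omitted.
train_cfg = dict(
    type='mmrazor.EvolutionSearchLoop',  # assumed registry name
    num_mutation=25,
    num_crossover=25,
    mutate_prob=0.1,
    # Interpreted in the estimator's default FLOPs unit ('M'): 0-330 MFLOPs.
    flops_range=(0., 330.),
    # Forwarded verbatim to ResourceEstimator(**resource_estimator_cfg);
    # leave it out (None) to fall back to the estimator defaults.
    resource_estimator_cfg=dict(input_shape=(1, 3, 224, 224)),
    score_key='accuracy/top1')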