From 1ae7f5da2932a1eefc44a2c3336b43bee2a92793 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Mon, 5 Sep 2022 12:29:41 +0800 Subject: [PATCH 1/7] update estimator usage and fix bugs --- .../engine/runner/evolution_search_loop.py | 15 +- mmrazor/engine/runner/subnet_sampler_loop.py | 15 +- .../task_modules/estimators/base_estimator.py | 35 ++-- .../counters/flops_params_counter.py | 67 +++++--- .../estimators/counters/latency_counter.py | 77 +++++---- .../estimators/resource_estimator.py | 161 ++++++++++-------- .../test_estimators/test_flops_params.py | 51 ++++-- .../test_evolution_search_loop.py | 6 +- .../test_runners/test_subnet_sampler_loop.py | 5 +- 9 files changed, 246 insertions(+), 186 deletions(-) diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index a72704e40..faa564956 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -14,7 +14,7 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules.estimators import get_model_complexity_info +from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet from mmrazor.utils import SupportRandomSubnet @@ -44,6 +44,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): mutate_prob (float): The probability of mutation. Defaults to 0.1. flops_range (tuple, optional): flops_range to be used for screening candidates. + resource_input_shape (Tuple): Input shape when measuring flops. + Default to (1, 3, 224, 224). spec_modules (list): Used for specify modules need to counter. Defaults to list(). score_key (str): Specify one metric in evaluation results to score @@ -65,7 +67,8 @@ def __init__(self, num_mutation: int = 25, num_crossover: int = 25, mutate_prob: float = 0.1, - flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6), + flops_range: Optional[Tuple[float, float]] = (0., 330), + resource_input_shape: Tuple = (1, 3, 224, 224), spec_modules: List = [], score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: @@ -101,6 +104,7 @@ def __init__(self, correct init candidates file' self.top_k_candidates = Candidates() + self.estimator = ResourceEstimator(input_shape=resource_input_shape) if self.runner.distributed: self.model = runner.model.module @@ -306,10 +310,11 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: fix_mutable = export_fix_subnet(self.model) copied_model = copy.deepcopy(self.model) load_fix_subnet(copied_model, fix_mutable) - flops, _ = get_model_complexity_info( - copied_model, spec_modules=self.spec_modules) + results = self.estimator.estimate( + copied_model, spec_modules=self.spec_modules, as_strings=False) + flops = results['flops'] - if self.flops_range[0] <= flops <= self.flops_range[1]: + if self.flops_range[0] <= flops <= self.flops_range[1]: # type: ignore return True else: return False diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index c2b4d2176..eed31b9e5 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -13,7 +13,7 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules.estimators import get_model_complexity_info +from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet from mmrazor.utils import SupportRandomSubnet @@ -103,6 +103,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. flops_range (dict): Constraints to be used for screening candidates. + resource_input_shape (Tuple): Input shape when measuring flops. + Default to (1, 3, 224, 224). spec_modules (list): Used for specify modules need to counter. Defaults to list(). num_candidates (int): The number of the candidates consist of samples @@ -138,7 +140,8 @@ def __init__(self, val_begin: int = 1, val_interval: int = 1000, score_key: str = 'accuracy/top1', - flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6), + flops_range: Optional[Tuple[float, float]] = (0., 330), + resource_input_shape: Tuple = (1, 3, 224, 224), spec_modules: List = [], num_candidates: int = 1000, num_samples: int = 10, @@ -177,6 +180,7 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() + self.estimator = ResourceEstimator(input_shape=resource_input_shape) def run(self) -> None: """Launch training.""" @@ -324,10 +328,11 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: fix_mutable = export_fix_subnet(self.model) copied_model = copy.deepcopy(self.model) load_fix_subnet(copied_model, fix_mutable) - flops, _ = get_model_complexity_info( - copied_model, spec_modules=self.spec_modules) + results = self.estimator.estimate( + copied_model, spec_modules=self.spec_modules, as_strings=False) + flops = results['flops'] - if self.flops_range[0] <= flops <= self.flops_range[1]: + if self.flops_range[0] <= flops <= self.flops_range[1]: # type: ignore return True else: return False diff --git a/mmrazor/models/task_modules/estimators/base_estimator.py b/mmrazor/models/task_modules/estimators/base_estimator.py index 22a82d105..497b19fe0 100644 --- a/mmrazor/models/task_modules/estimators/base_estimator.py +++ b/mmrazor/models/task_modules/estimators/base_estimator.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Dict, Tuple, Union import torch.nn @@ -12,44 +12,33 @@ class BaseEstimator(metaclass=ABCMeta): """The base class of Estimator, used for estimating model infos. Args: - default_shape (tuple): Input data's default shape, for calculating + input_shape (tuple): Input data's default shape, for calculating resources consume. Defaults to (1, 3, 224, 224). - units (str): Resource units. Defaults to 'M'. - disabled_counters (list): List of disabled spec op counters. - Defaults to None. + units (dict): A dict including required units. Default to dict(). as_strings (bool): Output FLOPs and params counts in a string form. Default to False. - measure_inference (bool): whether to measure infer speed or not. - Default to False. """ def __init__(self, - default_shape: Tuple = (1, 3, 224, 224), - units: str = 'M', - disabled_counters: List[str] = None, - as_strings: bool = False, - measure_inference: bool = False): - assert len(default_shape) in [3, 4, 5], \ - f'Unsupported shape: {default_shape}' - self.default_shape = default_shape + input_shape: Tuple = (1, 3, 224, 224), + units: Dict = dict(), + as_strings: bool = False): + assert len(input_shape) == 4, ( + f'The length of input_shape must be 4. Got {len(input_shape)}.') + self.input_shape = input_shape self.units = units - self.disabled_counters = disabled_counters self.as_strings = as_strings - self.measure_inference = measure_inference @abstractmethod - def estimate( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, float]: + def estimate(self, model: torch.nn.Module, + **kwargs) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. Args: model: The measured model. - resource_args (Dict[str, float]): resources information. - NOTE: resource_args have the same items() as the init cfgs. Returns: - Dict[str, float]): A dict that containing resource results(flops, + Dict[str, float]): A dict that contains resource results(flops, params and latency). """ pass diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index 31e998a2a..f666b9771 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import sys from functools import partial +from typing import Dict import torch import torch.nn as nn @@ -13,7 +14,9 @@ def get_model_complexity_info(model, spec_modules=[], disabled_counters=[], print_per_layer_stat=False, + units=dict(flops='M', params='M'), as_strings=False, + seperate_return: bool = False, input_constructor=None, flush=False, ost=sys.stdout): @@ -39,7 +42,7 @@ def get_model_complexity_info(model, Args: model (nn.Module): The model for complexity calculation. input_shape (tuple): Input shape (including batchsize) used for - calculation. Default to (1, 3, 224, 224) + calculation. Default to (1, 3, 224, 224). spec_modules (list): A list that contains the names of several spec modules, which users want to get resources infos of them. e.g., ['backbone', 'head'], ['backbone.layer1']. Default to []. @@ -47,6 +50,8 @@ def get_model_complexity_info(model, calculated. Default to []. print_per_layer_stat (bool): Whether to print complexity information for each layer in a model. Default to True. + units (dict): A dict including converted FLOPs and params units. + Default to dict(flops='M', params='M'). as_strings (bool): Output FLOPs and params counts in a string form. Default to True. input_constructor (None | callable): If specified, it takes a callable @@ -60,12 +65,20 @@ def get_model_complexity_info(model, tuple[float | str] | dict[str, float]: If `as_strings` is set to True, it will return FLOPs and parameter counts in a string format. Otherwise, it will return those in a float number format. - If len(spec_modules) > 0, it will return a resource info dict with - FLOPs and parameter counts of each spec module in float format. + NOTE: If seperate_return, it will return a resource info dict with + FLOPs & params counts of each spec module in float|string format. """ assert type(input_shape) is tuple assert len(input_shape) >= 1 assert isinstance(model, nn.Module) + if seperate_return and not len(spec_modules): + raise AssertionError(f'seperate_return can only be set to True when ' + f'spec_modules are not empty. Got spec_modules=' + f'{spec_modules}.') + if as_strings: + flops_suffix = ' ' + units['flops'] + 'FLOPs' if units else ' FLOPs' + params_suffix = ' ' + units['params'] if units else '' + flops_params_model = add_flops_params_counting_methods(model) flops_params_model.eval() flops_params_model.start_flops_params_count(disabled_counters) @@ -96,34 +109,40 @@ def get_model_complexity_info(model, ost=ost, flush=flush) + if units is not None: + flops_count = params_units_convert(flops_count, units['flops']) + params_count = params_units_convert(params_count, units['params']) + if len(spec_modules): + flops_count, params_count = 0.0, 0.0 module_names = [name for name, _ in flops_params_model.named_modules()] for module in spec_modules: assert module in module_names, \ f'All modules in spec_modules should be in the measured ' \ f'flops_params_model. Got module {module} in spec_modules.' - spec_modules_resources = dict() - accumulate_sub_module_flops_params(flops_params_model) + spec_modules_resources: Dict[str, dict] = dict() + accumulate_sub_module_flops_params(flops_params_model, units=units) for name, module in flops_params_model.named_modules(): if name in spec_modules: spec_modules_resources[name] = dict() spec_modules_resources[name]['flops'] = module.__flops__ spec_modules_resources[name]['params'] = module.__params__ + flops_count += module.__flops__ + params_count += module.__params__ if as_strings: - spec_modules_resources[name]['flops'] = str( - params_units_convert(module.__flops__, - 'G')) + ' GFLOPs' - spec_modules_resources[name]['params'] = str( - params_units_convert(module.__params__, 'M')) + ' M' + spec_modules_resources[name]['flops'] = \ + str(module.__flops__) + flops_suffix + spec_modules_resources[name]['params'] = \ + str(module.__params__) + params_suffix flops_params_model.stop_flops_params_count() - if len(spec_modules): + if seperate_return: return spec_modules_resources if as_strings: - flops_string = str(params_units_convert(flops_count, 'G')) + ' GFLOPs' - params_string = str(params_units_convert(params_count, 'M')) + ' M' + flops_string = str(flops_count) + flops_suffix + params_string = str(params_count) + params_suffix return flops_string, params_string return flops_count, params_count @@ -164,7 +183,7 @@ def params_units_convert(num_params, units='M', precision=3): def print_model_with_flops_params(model, total_flops, total_params, - units='G', + units=dict(flops='M', params='M'), precision=3, ost=sys.stdout, flush=False): @@ -174,7 +193,9 @@ def print_model_with_flops_params(model, model (nn.Module): The model to be printed. total_flops (float): Total FLOPs of the model. total_params (float): Total parameter counts of the model. - units (str | None): Converted FLOPs units. Default to 'G'. + units (tuple | none): A tuple pair including converted FLOPs & params + units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'. + Default to ('M', 'M'). precision (int): Digit number after the decimal point. Default to 3. ost (stream): same as `file` param in :func:`print`. Default to sys.stdout. @@ -241,11 +262,11 @@ def flops_repr(self): accumulated_flops_cost = self.accumulate_flops() flops_string = str( params_units_convert( - accumulated_flops_cost, units=units, - precision=precision)) + ' ' + units + 'FLOPs' + accumulated_flops_cost, units['flops'], + precision=precision)) + ' ' + units['flops'] + 'FLOPs' params_string = str( - params_units_convert( - accumulated_num_params, units='M', precision=precision)) + ' M' + params_units_convert(accumulated_num_params, units['params'], + precision)) + ' M' return ', '.join([ params_string, '{:.3%} Params'.format(accumulated_num_params / total_params), @@ -277,12 +298,15 @@ def del_extra_repr(m): model.apply(del_extra_repr) -def accumulate_sub_module_flops_params(model): +def accumulate_sub_module_flops_params(model, units=None): """Accumulate FLOPs and params for each module in the model. Each module in the model will have the `__flops__` and `__params__` parameters. Args: model (nn.Module): The model to be accumulated. + units (tuple | none): A tuple pair including converted FLOPs & params + units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'. + Default to None. """ def accumulate_params(module): @@ -310,6 +334,9 @@ def accumulate_flops(module): _params = accumulate_params(module) module.__flops__ = _flops module.__params__ = _params + if units is not None: + module.__flops__ = params_units_convert(_flops, units['flops']) + module.__params__ = params_units_convert(_params, units['params']) def get_model_parameters_number(model): diff --git a/mmrazor/models/task_modules/estimators/counters/latency_counter.py b/mmrazor/models/task_modules/estimators/counters/latency_counter.py index e3e91c54e..58c888b5c 100644 --- a/mmrazor/models/task_modules/estimators/counters/latency_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/latency_counter.py @@ -1,42 +1,49 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging import time -from typing import Any, Dict +from typing import Tuple import torch from mmengine.logging import print_log def repeat_measure_inference_speed(model: torch.nn.Module, - resource_args: Dict[str, Any], - max_iter: int = 100, - num_warmup: int = 5, - log_interval: int = 100, - repeat_num: int = 1) -> float: + input_shape: Tuple = (1, 3, 224, 224), + latency_max_iter: int = 100, + latency_num_warmup: int = 5, + latency_log_interval: int = 100, + latency_repeat_num: int = 1, + unit: str = 'ms', + as_strings: bool = False) -> float: """Repeat speed measure for multi-times to get more precise results. Args: model (torch.nn.Module): The measured model. - resource_args (Dict[str, float]): resources information. - max_iter (Optional[int]): Max iteration num for inference speed test. - num_warmup (Optional[int]): Iteration num for warm-up stage. - log_interval (Optional[int]): Interval num for logging the results. - repeat_num (Optional[int]): Num of times to repeat the measurement. + input_shape (tuple): Input shape (including batchsize) used for + calculation. Default to (1, 3, 224, 224). + latency_max_iter (Optional[int]): Max iteration num for the + measurement. Default to 100. + latency_num_warmup (Optional[int]): Iteration num for warm-up stage. + Default to 5. + latency_log_interval (Optional[int]): Interval num for logging the + results. Default to 100. + latency_repeat_num (Optional[int]): Num of times to repeat the + measurement. Default to 1. Returns: fps (float): The measured inference speed of the model. """ - assert repeat_num >= 1 + assert latency_repeat_num >= 1 fps_list = [] - for _ in range(repeat_num): + for _ in range(latency_repeat_num): fps_list.append( - measure_inference_speed(model, resource_args, max_iter, num_warmup, - log_interval)) + measure_inference_speed(model, input_shape, latency_max_iter, + latency_num_warmup, latency_log_interval)) - if repeat_num > 1: + if latency_repeat_num > 1: fps_list_ = [round(fps, 1) for fps in fps_list] times_per_img_list = [round(1000 / fps, 1) for fps in fps_list] mean_fps_ = sum(fps_list_) / len(fps_list_) @@ -54,18 +61,22 @@ def repeat_measure_inference_speed(model: torch.nn.Module, def measure_inference_speed(model: torch.nn.Module, - resource_args: Dict[str, Any], - max_iter: int = 100, - num_warmup: int = 5, - log_interval: int = 100) -> float: + input_shape: Tuple = (1, 3, 224, 224), + latency_max_iter: int = 100, + latency_num_warmup: int = 5, + latency_log_interval: int = 100) -> float: """Measure inference speed on GPU devices. Args: model (torch.nn.Module): The measured model. - resource_args (Dict[str, float]): resources information. - max_iter (Optional[int]): Max iteration num for inference speed test. - num_warmup (Optional[int]): Iteration num for warm-up stage. - log_interval (Optional[int]): Interval num for logging the results. + input_shape (tuple): Input shape (including batchsize) used for + calculation. Default to (1, 3, 224, 224). + latency_max_iter (Optional[int]): Max iteration num for the + measurement. Default to 100. + latency_num_warmup (Optional[int]): Iteration num for warm-up stage. + Default to 5. + latency_log_interval (Optional[int]): Interval num for logging the + results. Default to 100. Returns: fps (float): The measured inference speed of the model. @@ -78,10 +89,10 @@ def measure_inference_speed(model: torch.nn.Module, device = 'cuda' else: raise NotImplementedError('To use cpu to test latency not supported.') - # benchmark with {max_iter} image and take the average - for i in range(1, max_iter): + # benchmark with {latency_max_iter} image and take the average + for i in range(1, latency_max_iter): if device == 'cuda': - data = torch.rand(resource_args['input_shape']).cuda() + data = torch.rand(input_shape).cuda() torch.cuda.synchronize() start_time = time.perf_counter() @@ -91,19 +102,19 @@ def measure_inference_speed(model: torch.nn.Module, torch.cuda.synchronize() elapsed = time.perf_counter() - start_time - if i >= num_warmup: + if i >= latency_num_warmup: pure_inf_time += elapsed - if (i + 1) % log_interval == 0: - fps = (i + 1 - num_warmup) / pure_inf_time + if (i + 1) % latency_log_interval == 0: + fps = (i + 1 - latency_num_warmup) / pure_inf_time print_log( - f'Done image [{i + 1:<3}/ {max_iter}], ' + f'Done image [{i + 1:<3}/ {latency_max_iter}], ' f'fps: {fps:.1f} img / s, ' f'times per image: {1000 / fps:.1f} ms / img', logger='current', level=logging.DEBUG) - if (i + 1) == max_iter: - fps = (i + 1 - num_warmup) / pure_inf_time + if (i + 1) == latency_max_iter: + fps = (i + 1 - latency_num_warmup) / pure_inf_time print_log( f'Overall fps: {fps:.1f} img / s, ' f'times per image: {1000 / fps:.1f} ms / img', diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py index 6d4342866..eaf5501f6 100644 --- a/mmrazor/models/task_modules/estimators/resource_estimator.py +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -1,13 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple, Union import torch.nn -from mmengine.dist import broadcast_object_list, is_main_process from mmrazor.registry import TASK_UTILS from .base_estimator import BaseEstimator -from .counters import (get_model_complexity_info, params_units_convert, - repeat_measure_inference_speed) +from .counters import get_model_complexity_info, repeat_measure_inference_speed @TASK_UTILS.register_module() @@ -15,14 +13,28 @@ class ResourceEstimator(BaseEstimator): """Estimator for calculating the resources consume. Args: - default_shape (tuple): Input data's default shape, for calculating - resources consume. Defaults to (1, 3, 224, 224) - units (str): Resource units. Defaults to 'M'. + input_shape (tuple): Input data's default shape, for calculating + resources consume. Defaults to (1, 3, 224, 224). + units (dict): Dict that contains converted FLOPs/params/latency units. + Default to dict(flops='M', params='M', latency='ms'). + as_strings (bool): Output FLOPs/params/latency counts in a string + form. Default to False. + spec_modules (list): List of spec modules that needed to count. + e.g., ['backbone', 'head'], ['backbone.layer1']. Default to []. disabled_counters (list): List of disabled spec op counters. - Defaults to None. - NOTE: disabled_counters contains the op counter class names - in estimator.op_counters that require to be disabled, - such as 'ConvCounter', 'BatchNorm2dCounter', ... + It contains the op counter names in estimator.op_counters that + are required to be disabled, e.g., ['BatchNorm2dCounter']. + Defaults to []. + measure_latency (bool): whether to measure inference speed or not. + Default to False. + latency_max_iter (Optional[int]): Max iteration num for the + measurement. Default to 100. + latency_num_warmup (Optional[int]): Iteration num for warm-up stage. + Default to 5. + latency_log_interval (Optional[int]): Interval num for logging the + results. Default to 100. + latency_repeat_num (Optional[int]): Num of times to repeat the + measurement. Default to 1. Examples: >>> # direct calculate resource consume of nn.Conv2d @@ -30,7 +42,7 @@ class ResourceEstimator(BaseEstimator): >>> estimator = ResourceEstimator() >>> estimator.estimate( ... model=conv2d, - ... resource_args=dict(input_shape=(1, 3, 64, 64))) + ... input_shape=(1, 3, 64, 64)) {'flops': 3.444, 'params': 0.001, 'latency': 0.0} >>> # calculate resources of custom modules @@ -53,15 +65,14 @@ class ResourceEstimator(BaseEstimator): >>> model = CustomModule() >>> estimator.estimate( ... model=model, - ... resource_args=dict(input_shape=(1, 3, 64, 64))) + ... input_shape=(1, 3, 64, 64)) {'flops': 1.0, 'params': 0.7, 'latency': 0.0} ... >>> # calculate resources of custom modules with disable_counters >>> estimator.estimate( ... model=model, - ... resource_args=dict( - ... input_shape=(1, 3, 64, 64), - ... disabled_counters=['CustomModuleCounter'])) + ... input_shape=(1, 3, 64, 64), + ... disabled_counters=['CustomModuleCounter']) {'flops': 0.0, 'params': 0.0, 'latency': 0.0} >>> # calculate resources of mmrazor.models @@ -69,87 +80,87 @@ class ResourceEstimator(BaseEstimator): mmrazor.engine.hooks.estimate_resources_hook for details. """ - def __init__(self, - default_shape: Tuple = (1, 3, 224, 224), - units: str = 'M', - disabled_counters: List[str] = [], - as_strings: bool = False, - measure_inference: bool = False): - super().__init__(default_shape, units, disabled_counters, as_strings, - measure_inference) - - def estimate( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, Any]: + def __init__( + self, + input_shape: Tuple = (1, 3, 224, 224), + units: Dict = dict(flops='M', params='M', latency='ms'), + as_strings: bool = False, + spec_modules: List[str] = [], + disabled_counters: List[str] = [], + measure_latency: bool = False, + latency_max_iter: int = 100, + latency_num_warmup: int = 5, + latency_log_interval: int = 100, + latency_repeat_num: int = 1, + ): + super().__init__(input_shape, units, as_strings) + self.spec_modules = spec_modules + self.disabled_counters = disabled_counters + + self.measure_latency = measure_latency + self.latency_max_iter = latency_max_iter + self.latency_num_warmup = latency_num_warmup + self.latency_log_interval = latency_log_interval + self.latency_repeat_num = latency_repeat_num + + def estimate(self, model: torch.nn.Module, + **kwargs) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. Args: model: The measured model. - resource_args (Dict[str, float]): Args for resources estimation. - NOTE: resource_args have the same items() as the init cfgs. Returns: Dict[str, str]): A dict that containing resource results(flops, params and latency). """ + latency_cfg = dict() resource_metrics = dict() - if is_main_process(): - measure_inference = resource_args.pop('measure_inference', False) - if 'input_shape' not in resource_args.keys(): - resource_args['input_shape'] = self.default_shape - if 'disabled_counters' not in resource_args.keys(): - resource_args['disabled_counters'] = self.disabled_counters - model.eval() - flops, params = get_model_complexity_info(model, **resource_args) - if measure_inference: - latency = repeat_measure_inference_speed( - model, resource_args, max_iter=100, repeat_num=2) - else: - latency = 0.0 - as_strings = resource_args.get('as_strings', self.as_strings) - if as_strings and self.units is not None: - raise ValueError('Set units to None, when as_trings=True.') - if self.units is not None: - flops = params_units_convert(flops, self.units) - params = params_units_convert(params, self.units) - resource_metrics.update({ - 'flops': flops, - 'params': params, - 'latency': latency - }) - results = [resource_metrics] - else: - results = [None] # type: ignore + for key in self.__init__.__code__.co_varnames[1:]: # type: ignore + kwargs.setdefault(key, getattr(self, key)) + if 'latency' in key: + latency_cfg[key] = kwargs.pop(key) + latency_cfg['unit'] = kwargs['units'].get('latency') + latency_cfg['as_strings'] = kwargs['as_strings'] + latency_cfg['input_shape'] = kwargs['input_shape'] - broadcast_object_list(results) + model.eval() + flops, params = get_model_complexity_info(model, **kwargs) - return results[0] + if latency_cfg['measure_latency']: + latency = repeat_measure_inference_speed(model, **latency_cfg) + else: + latency = '0.0 ms' if kwargs['as_strings'] else 0.0 # type: ignore - def estimate_spec_modules( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, float]: + resource_metrics.update({ + 'flops': flops, + 'params': params, + 'latency': latency + }) + return resource_metrics + + def estimate_separation_modules(self, model: torch.nn.Module, + **kwargs) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the spec modules. Args: model: The measured model. - resource_args (Dict[str, float]): Args for resources estimation. - NOTE: resource_args have the same items() as the init cfgs. Returns: Dict[str, float]): A dict that containing resource results(flops, - params) of each modules in resource_args['spec_modules']. + params) of each modules in kwargs['spec_modules']. """ - assert 'spec_modules' in resource_args, \ - 'spec_modules is required when calling estimate_spec_modules().' + for key in self.__init__.__code__.co_varnames[1:]: # type: ignore + kwargs.setdefault(key, getattr(self, key)) + # TODO: support speed estimation for separation modules. + if 'latency' in key: + kwargs.pop(key) - resource_args.pop('measure_inference', False) - if 'input_shape' not in resource_args.keys(): - resource_args['input_shape'] = self.default_shape - if 'disabled_counters' not in resource_args.keys(): - resource_args['disabled_counters'] = self.disabled_counters + assert len(kwargs['spec_modules']), ( + f'spec_modules can not be empty when calling ' + f'{self.__class__.__name__}.estimate_separation_modules().') + kwargs['seperate_return'] = True model.eval() - spec_modules_resources = get_model_complexity_info( - model, **resource_args) - + spec_modules_resources = get_model_complexity_info(model, **kwargs) return spec_modules_resources diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py index 99be89bca..ee6ac6bd0 100644 --- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -119,8 +119,7 @@ def sample_choice(self, model: Module) -> None: def test_estimate(self) -> None: fool_conv2d = FoolConv2d() results = estimator.estimate( - model=fool_conv2d, - resource_args=dict(input_shape=(1, 3, 224, 224))) + model=fool_conv2d, input_shape=(1, 3, 224, 224)) flops_count = results['flops'] params_count = results['params'] @@ -130,8 +129,7 @@ def test_estimate(self) -> None: def test_register_module(self) -> None: fool_add_constant = FoolConvModule() results = estimator.estimate( - model=fool_add_constant, - resource_args=dict(input_shape=(1, 3, 224, 224))) + model=fool_add_constant, input_shape=(1, 3, 224, 224)) flops_count = results['flops'] params_count = results['params'] @@ -142,44 +140,61 @@ def test_disable_sepc_counter(self) -> None: fool_add_constant = FoolConvModule() rest_results = estimator.estimate( model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), - disabled_counters=['FoolAddConstantCounter'])) + input_shape=(1, 3, 224, 224), + disabled_counters=['FoolAddConstantCounter']) rest_flops_count = rest_results['flops'] rest_params_count = rest_results['params'] self.assertLess(rest_flops_count, 45.158) self.assertLess(rest_params_count, 0.701) - def test_estimate_spec_modules(self) -> None: + def test_estimate_spec_module(self) -> None: + fool_add_constant = FoolConvModule() + results = estimator.estimate( + model=fool_add_constant, + input_shape=(1, 3, 224, 224), + spec_modules=['add_constant', 'conv2d']) + flops_count = results['flops'] + params_count = results['params'] + + self.assertEqual(flops_count, 45.158) + self.assertEqual(params_count, 0.701) + + def test_estimate_separation_modules(self) -> None: fool_add_constant = FoolConvModule() - results = estimator.estimate_spec_modules( + results = estimator.estimate_separation_modules( model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), spec_modules=['add_constant'])) + input_shape=(1, 3, 224, 224), + spec_modules=['add_constant']) self.assertGreater(results['add_constant']['flops'], 0) with pytest.raises(AssertionError): - results = estimator.estimate_spec_modules( + results = estimator.estimate_separation_modules( + model=fool_add_constant, + input_shape=(1, 3, 224, 224), + spec_modules=['backbone']) + + with pytest.raises(AssertionError): + results = estimator.estimate_separation_modules( model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), spec_modules=['backbone'])) + input_shape=(1, 3, 224, 224), + spec_modules=[]) def test_estimate_subnet(self) -> None: - resource_args = dict(input_shape=(1, 3, 224, 224)) + input_shape = (1, 3, 224, 224) model = MODELS.build(BACKBONE_CFG) self.sample_choice(model) copied_model = copy.deepcopy(model) results = estimator.estimate( - model=copied_model, resource_args=resource_args) + model=copied_model, input_shape=input_shape) flops_count = results['flops'] params_count = results['params'] fix_subnet = export_fix_subnet(model) load_fix_subnet(copied_model, fix_subnet) subnet_results = estimator.estimate( - model=copied_model, resource_args=resource_args) + model=copied_model, input_shape=input_shape) subnet_flops_count = subnet_results['flops'] subnet_params_count = subnet_results['params'] @@ -189,7 +204,7 @@ def test_estimate_subnet(self) -> None: # test whether subnet estimate will affect original model copied_model = copy.deepcopy(model) results_after_estimate = \ - estimator.estimate(model=copied_model, resource_args=resource_args) + estimator.estimate(model=copied_model, input_shape=input_shape) flops_count_after_estimate = results_after_estimate['flops'] params_count_after_estimate = results_after_estimate['params'] diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py index 14e642c57..8655c9c4a 100644 --- a/tests/test_runners/test_evolution_search_loop.py +++ b/tests/test_runners/test_evolution_search_loop.py @@ -112,9 +112,7 @@ def test_init(self): self.assertEqual(loop.candidates, fake_candidates) @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet') - @patch( - 'mmrazor.engine.runner.evolution_search_loop.get_model_complexity_info' - ) + @patch('mmrazor.models.task_modules.ResourceEstimator.estimate') def test_run_epoch(self, mock_flops, mock_export_fix_subnet): # test_run_epoch: distributed == False loop_cfg = copy.deepcopy(self.train_cfg) @@ -155,7 +153,7 @@ def test_run_epoch(self, mock_flops, mock_export_fix_subnet): self.runner.work_dir = self.temp_dir fake_subnet = {'1': 'choice1', '2': 'choice2'} loop.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = (50., 1) + mock_flops.return_value = dict(flops=10.0, params=2.0) mock_export_fix_subnet.return_value = fake_subnet loop.run_epoch() self.assertEqual(len(loop.candidates), 4) diff --git a/tests/test_runners/test_subnet_sampler_loop.py b/tests/test_runners/test_subnet_sampler_loop.py index 0f26c5aeb..1271b864b 100644 --- a/tests/test_runners/test_subnet_sampler_loop.py +++ b/tests/test_runners/test_subnet_sampler_loop.py @@ -193,8 +193,7 @@ def test_sample_subnet(self): self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1) @patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet') - @patch( - 'mmrazor.engine.runner.subnet_sampler_loop.get_model_complexity_info') + @patch('mmrazor.models.task_modules.ResourceEstimator.estimate') def test_run(self, mock_flops, mock_export_fix_subnet): # test run with flops_range=None cfg = copy.deepcopy(self.iter_based_cfg) @@ -214,7 +213,7 @@ def test_run(self, mock_flops, mock_export_fix_subnet): runner = Runner.from_cfg(cfg) fake_subnet = {'1': 'choice1', '2': 'choice2'} runner.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = (50., 1) + mock_flops.return_value = dict(flops=10.0, params=2.0) mock_export_fix_subnet.return_value = fake_subnet runner.train() From bbbb675b251d95bc445d4d711e1126948c82c211 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Thu, 8 Sep 2022 17:19:51 +0800 Subject: [PATCH 2/7] refactor api of estimator & add inner check methods --- .../engine/runner/evolution_search_loop.py | 12 +- mmrazor/engine/runner/subnet_sampler_loop.py | 12 +- .../task_modules/estimators/base_estimator.py | 19 +- .../estimators/counters/__init__.py | 10 +- .../counters/flops_params_counter.py | 50 ++--- .../estimators/counters/latency_counter.py | 87 +++++---- .../estimators/resource_estimator.py | 182 ++++++++++++------ .../test_estimators/test_flops_params.py | 44 +++-- 8 files changed, 245 insertions(+), 171 deletions(-) diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index faa564956..263fb76e1 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -46,7 +46,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): candidates. resource_input_shape (Tuple): Input shape when measuring flops. Default to (1, 3, 224, 224). - spec_modules (list): Used for specify modules need to counter. + resource_spec_modules (list): Used for specify modules need to counter. Defaults to list(). score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. @@ -69,7 +69,7 @@ def __init__(self, mutate_prob: float = 0.1, flops_range: Optional[Tuple[float, float]] = (0., 330), resource_input_shape: Tuple = (1, 3, 224, 224), - spec_modules: List = [], + resource_spec_modules: List = [], score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -88,7 +88,6 @@ def __init__(self, self.num_candidates = num_candidates self.top_k = top_k self.flops_range = flops_range - self.spec_modules = spec_modules self.score_key = score_key self.num_mutation = num_mutation self.num_crossover = num_crossover @@ -104,7 +103,9 @@ def __init__(self, correct init candidates file' self.top_k_candidates = Candidates() - self.estimator = ResourceEstimator(input_shape=resource_input_shape) + self.estimator = ResourceEstimator( + input_shape=resource_input_shape, + flops_params_cfg=dict(spec_modules=resource_spec_modules)) if self.runner.distributed: self.model = runner.model.module @@ -310,8 +311,7 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: fix_mutable = export_fix_subnet(self.model) copied_model = copy.deepcopy(self.model) load_fix_subnet(copied_model, fix_mutable) - results = self.estimator.estimate( - copied_model, spec_modules=self.spec_modules, as_strings=False) + results = self.estimator.estimate(copied_model) flops = results['flops'] if self.flops_range[0] <= flops <= self.flops_range[1]: # type: ignore diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index eed31b9e5..9beeaf2b4 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -105,7 +105,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): flops_range (dict): Constraints to be used for screening candidates. resource_input_shape (Tuple): Input shape when measuring flops. Default to (1, 3, 224, 224). - spec_modules (list): Used for specify modules need to counter. + resource_spec_modules (list): Used for specify modules need to counter. Defaults to list(). num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. @@ -142,7 +142,7 @@ def __init__(self, score_key: str = 'accuracy/top1', flops_range: Optional[Tuple[float, float]] = (0., 330), resource_input_shape: Tuple = (1, 3, 224, 224), - spec_modules: List = [], + resource_spec_modules: List = [], num_candidates: int = 1000, num_samples: int = 10, top_k: int = 5, @@ -166,7 +166,6 @@ def __init__(self, self.score_key = score_key self.flops_range = flops_range - self.spec_modules = spec_modules self.num_candidates = num_candidates self.num_samples = num_samples self.top_k = top_k @@ -180,7 +179,9 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() - self.estimator = ResourceEstimator(input_shape=resource_input_shape) + self.estimator = ResourceEstimator( + input_shape=resource_input_shape, + flops_params_cfg=dict(spec_modules=resource_spec_modules)) def run(self) -> None: """Launch training.""" @@ -328,8 +329,7 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: fix_mutable = export_fix_subnet(self.model) copied_model = copy.deepcopy(self.model) load_fix_subnet(copied_model, fix_mutable) - results = self.estimator.estimate( - copied_model, spec_modules=self.spec_modules, as_strings=False) + results = self.estimator.estimate(copied_model) flops = results['flops'] if self.flops_range[0] <= flops <= self.flops_range[1]: # type: ignore diff --git a/mmrazor/models/task_modules/estimators/base_estimator.py b/mmrazor/models/task_modules/estimators/base_estimator.py index 497b19fe0..1a6f69264 100644 --- a/mmrazor/models/task_modules/estimators/base_estimator.py +++ b/mmrazor/models/task_modules/estimators/base_estimator.py @@ -23,22 +23,29 @@ def __init__(self, input_shape: Tuple = (1, 3, 224, 224), units: Dict = dict(), as_strings: bool = False): - assert len(input_shape) == 4, ( - f'The length of input_shape must be 4. Got {len(input_shape)}.') + assert len(input_shape) in [ + 3, 4, 5 + ], ('The length of input_shape must be in [3, 4, 5]. ' + f'Got `{len(input_shape)}`.') self.input_shape = input_shape self.units = units self.as_strings = as_strings @abstractmethod - def estimate(self, model: torch.nn.Module, - **kwargs) -> Dict[str, Union[float, str]]: + def estimate(self, + model: torch.nn.Module, + flops_params_cfg: dict = None, + latency_cfg: dict = None) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. Args: model: The measured model. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. Returns: - Dict[str, float]): A dict that contains resource results(flops, - params and latency). + Dict[str, Union[float, str]]): A dict that contains the resource + results(FLOPs, params and latency). """ pass diff --git a/mmrazor/models/task_modules/estimators/counters/__init__.py b/mmrazor/models/task_modules/estimators/counters/__init__.py index 0a6adee48..721987ec1 100644 --- a/mmrazor/models/task_modules/estimators/counters/__init__.py +++ b/mmrazor/models/task_modules/estimators/counters/__init__.py @@ -1,10 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .flops_params_counter import (get_model_complexity_info, - params_units_convert) -from .latency_counter import repeat_measure_inference_speed +from .flops_params_counter import get_model_flops_params +from .latency_counter import get_model_latency from .op_counters import * # noqa: F401,F403 -__all__ = [ - 'get_model_complexity_info', 'params_units_convert', - 'repeat_measure_inference_speed' -] +__all__ = ['get_model_flops_params', 'get_model_latency'] diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index f666b9771..ac37e1c39 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -9,21 +9,21 @@ from mmrazor.registry import TASK_UTILS -def get_model_complexity_info(model, - input_shape=(1, 3, 224, 224), - spec_modules=[], - disabled_counters=[], - print_per_layer_stat=False, - units=dict(flops='M', params='M'), - as_strings=False, - seperate_return: bool = False, - input_constructor=None, - flush=False, - ost=sys.stdout): - """Get complexity information of a model. This method can calculate FLOPs - and parameter counts of a model with corresponding input shape. It can also - print complexity information for each layer in a model. Supported layers - are listed as below: +def get_model_flops_params(model, + input_shape=(1, 3, 224, 224), + spec_modules=[], + disabled_counters=[], + print_per_layer_stat=False, + units=dict(flops='M', params='M'), + as_strings=False, + seperate_return: bool = False, + input_constructor=None, + flush=False, + ost=sys.stdout): + """Get FLOPs and parameters of a model. This method can calculate FLOPs and + parameter counts of a model with corresponding input shape. It can also + print FLOPs and params for each layer in a model. Supported layers are + listed as below: - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``. - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``, @@ -48,7 +48,7 @@ def get_model_complexity_info(model, e.g., ['backbone', 'head'], ['backbone.layer1']. Default to []. disabled_counters (list): One can limit which ops' spec would be calculated. Default to []. - print_per_layer_stat (bool): Whether to print complexity information + print_per_layer_stat (bool): Whether to print FLOPs and params for each layer in a model. Default to True. units (dict): A dict including converted FLOPs and params units. Default to dict(flops='M', params='M'). @@ -72,12 +72,8 @@ def get_model_complexity_info(model, assert len(input_shape) >= 1 assert isinstance(model, nn.Module) if seperate_return and not len(spec_modules): - raise AssertionError(f'seperate_return can only be set to True when ' - f'spec_modules are not empty. Got spec_modules=' - f'{spec_modules}.') - if as_strings: - flops_suffix = ' ' + units['flops'] + 'FLOPs' if units else ' FLOPs' - params_suffix = ' ' + units['params'] if units else '' + raise AssertionError('`seperate_return` can only be set to True when ' + '`spec_modules` are not empty.') flops_params_model = add_flops_params_counting_methods(model) flops_params_model.eval() @@ -113,13 +109,17 @@ def get_model_complexity_info(model, flops_count = params_units_convert(flops_count, units['flops']) params_count = params_units_convert(params_count, units['params']) + if as_strings: + flops_suffix = ' ' + units['flops'] + 'FLOPs' if units else ' FLOPs' + params_suffix = ' ' + units['params'] if units else '' + if len(spec_modules): flops_count, params_count = 0.0, 0.0 module_names = [name for name, _ in flops_params_model.named_modules()] for module in spec_modules: assert module in module_names, \ f'All modules in spec_modules should be in the measured ' \ - f'flops_params_model. Got module {module} in spec_modules.' + f'flops_params_model. Got module `{module}` in spec_modules.' spec_modules_resources: Dict[str, dict] = dict() accumulate_sub_module_flops_params(flops_params_model, units=units) for name, module in flops_params_model.named_modules(): @@ -221,8 +221,8 @@ def print_model_with_flops_params(model, >>> return x >>> model = ExampleModel() >>> x = (3, 16, 16) - to print the complexity information state for each layer, you can use - >>> get_model_complexity_info(model, x) + to print the FLOPs and params state for each layer, you can use + >>> get_model_flops_params(model, x) or directly use >>> print_model_with_flops_params(model, 4579784.0, 37361) ExampleModel( diff --git a/mmrazor/models/task_modules/estimators/counters/latency_counter.py b/mmrazor/models/task_modules/estimators/counters/latency_counter.py index 58c888b5c..c86b4c849 100644 --- a/mmrazor/models/task_modules/estimators/counters/latency_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/latency_counter.py @@ -1,70 +1,74 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging import time -from typing import Tuple +from typing import Tuple, Union import torch from mmengine.logging import print_log -def repeat_measure_inference_speed(model: torch.nn.Module, - input_shape: Tuple = (1, 3, 224, 224), - latency_max_iter: int = 100, - latency_num_warmup: int = 5, - latency_log_interval: int = 100, - latency_repeat_num: int = 1, - unit: str = 'ms', - as_strings: bool = False) -> float: +def get_model_latency(model: torch.nn.Module, + input_shape: Tuple = (1, 3, 224, 224), + unit: str = 'ms', + as_strings: bool = False, + max_iter: int = 100, + num_warmup: int = 5, + log_interval: int = 100, + repeat_num: int = 1) -> Union[float, str]: """Repeat speed measure for multi-times to get more precise results. Args: model (torch.nn.Module): The measured model. input_shape (tuple): Input shape (including batchsize) used for calculation. Default to (1, 3, 224, 224). - latency_max_iter (Optional[int]): Max iteration num for the - measurement. Default to 100. - latency_num_warmup (Optional[int]): Iteration num for warm-up stage. + max_iter (Optional[int]): Max iteration num for the measurement. + Default to 100. + num_warmup (Optional[int]): Iteration num for warm-up stage. Default to 5. - latency_log_interval (Optional[int]): Interval num for logging the - results. Default to 100. - latency_repeat_num (Optional[int]): Num of times to repeat the - measurement. Default to 1. + log_interval (Optional[int]): Interval num for logging the results. + Default to 100. + repeat_num (Optional[int]): Num of times to repeat the measurement. + Default to 1. Returns: - fps (float): The measured inference speed of the model. + latency (Union[float, str]): The measured inference speed of the model. + if ``as_strings=True``, it will return latency in string format. """ - assert latency_repeat_num >= 1 + assert repeat_num >= 1 fps_list = [] - for _ in range(latency_repeat_num): - + for _ in range(repeat_num): fps_list.append( - measure_inference_speed(model, input_shape, latency_max_iter, - latency_num_warmup, latency_log_interval)) + _get_model_latency(model, input_shape, max_iter, num_warmup, + log_interval)) + + latency = round(1000 / fps_list[0], 1) - if latency_repeat_num > 1: - fps_list_ = [round(fps, 1) for fps in fps_list] + if repeat_num > 1: + _fps_list = [round(fps, 1) for fps in fps_list] times_per_img_list = [round(1000 / fps, 1) for fps in fps_list] - mean_fps_ = sum(fps_list_) / len(fps_list_) + _mean_fps = sum(_fps_list) / len(_fps_list) mean_times_per_img = sum(times_per_img_list) / len(times_per_img_list) print_log( - f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, ' + f'Overall fps: {_fps_list}[{_mean_fps:.1f}] img / s, ' f'times per image: ' f'{times_per_img_list}[{mean_times_per_img:.1f}] ms/img', logger='current', level=logging.DEBUG) - return mean_times_per_img + latency = mean_times_per_img + + if as_strings: + latency = str(latency) + ' ' + unit # type: ignore - latency = round(1000 / fps_list[0], 1) return latency -def measure_inference_speed(model: torch.nn.Module, - input_shape: Tuple = (1, 3, 224, 224), - latency_max_iter: int = 100, - latency_num_warmup: int = 5, - latency_log_interval: int = 100) -> float: +def _get_model_latency(model: torch.nn.Module, + input_shape: Tuple = (1, 3, 224, 224), + max_iter: int = 100, + num_warmup: int = 5, + log_interval: int = 100) -> float: """Measure inference speed on GPU devices. Args: @@ -89,8 +93,9 @@ def measure_inference_speed(model: torch.nn.Module, device = 'cuda' else: raise NotImplementedError('To use cpu to test latency not supported.') - # benchmark with {latency_max_iter} image and take the average - for i in range(1, latency_max_iter): + + # benchmark with {max_iter} image and take the average + for i in range(1, max_iter): if device == 'cuda': data = torch.rand(input_shape).cuda() torch.cuda.synchronize() @@ -102,19 +107,19 @@ def measure_inference_speed(model: torch.nn.Module, torch.cuda.synchronize() elapsed = time.perf_counter() - start_time - if i >= latency_num_warmup: + if i >= num_warmup: pure_inf_time += elapsed - if (i + 1) % latency_log_interval == 0: - fps = (i + 1 - latency_num_warmup) / pure_inf_time + if (i + 1) % log_interval == 0: + fps = (i + 1 - num_warmup) / pure_inf_time print_log( - f'Done image [{i + 1:<3}/ {latency_max_iter}], ' + f'Done image [{i + 1:<3}/ {max_iter}], ' f'fps: {fps:.1f} img / s, ' f'times per image: {1000 / fps:.1f} ms / img', logger='current', level=logging.DEBUG) - if (i + 1) == latency_max_iter: - fps = (i + 1 - latency_num_warmup) / pure_inf_time + if (i + 1) == max_iter: + fps = (i + 1 - num_warmup) / pure_inf_time print_log( f'Overall fps: {fps:.1f} img / s, ' f'times per image: {1000 / fps:.1f} ms / img', diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py index eaf5501f6..19eddd428 100644 --- a/mmrazor/models/task_modules/estimators/resource_estimator.py +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch.nn from mmrazor.registry import TASK_UTILS from .base_estimator import BaseEstimator -from .counters import get_model_complexity_info, repeat_measure_inference_speed +from .counters import get_model_flops_params, get_model_latency @TASK_UTILS.register_module() @@ -39,12 +39,17 @@ class ResourceEstimator(BaseEstimator): Examples: >>> # direct calculate resource consume of nn.Conv2d >>> conv2d = nn.Conv2d(3, 32, 3) - >>> estimator = ResourceEstimator() - >>> estimator.estimate( - ... model=conv2d, - ... input_shape=(1, 3, 64, 64)) + >>> estimator = ResourceEstimator(input_shape=(1, 3, 64, 64)) + >>> estimator.estimate(model=conv2d) {'flops': 3.444, 'params': 0.001, 'latency': 0.0} + >>> # direct calculate resource consume of nn.Conv2d + >>> conv2d = nn.Conv2d(3, 32, 3) + >>> estimator = ResourceEstimator() + >>> flops_params_cfg = dict(input_shape=(1, 3, 32, 32)) + >>> estimator.estimate(model=conv2d, flops_params_cfg) + {'flops': 0.806, 'params': 0.001, 'latency': 0.0} + >>> # calculate resources of custom modules >>> class CustomModule(nn.Module): ... @@ -63,16 +68,14 @@ class ResourceEstimator(BaseEstimator): ... module.__params__ += 700000 ... >>> model = CustomModule() - >>> estimator.estimate( - ... model=model, - ... input_shape=(1, 3, 64, 64)) + >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64)) + >>> estimator.estimate(model=model, flops_params_cfg) {'flops': 1.0, 'params': 0.7, 'latency': 0.0} ... >>> # calculate resources of custom modules with disable_counters - >>> estimator.estimate( - ... model=model, - ... input_shape=(1, 3, 64, 64), - ... disabled_counters=['CustomModuleCounter']) + >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64), + ... disabled_counters=['CustomModuleCounter']) + >>> estimator.estimate(model=model, flops_params_cfg) {'flops': 0.0, 'params': 0.0, 'latency': 0.0} >>> # calculate resources of mmrazor.models @@ -85,52 +88,69 @@ def __init__( input_shape: Tuple = (1, 3, 224, 224), units: Dict = dict(flops='M', params='M', latency='ms'), as_strings: bool = False, - spec_modules: List[str] = [], - disabled_counters: List[str] = [], - measure_latency: bool = False, - latency_max_iter: int = 100, - latency_num_warmup: int = 5, - latency_log_interval: int = 100, - latency_repeat_num: int = 1, + flops_params_cfg: Optional[dict] = None, + latency_cfg: Optional[dict] = None, ): super().__init__(input_shape, units, as_strings) - self.spec_modules = spec_modules - self.disabled_counters = disabled_counters - - self.measure_latency = measure_latency - self.latency_max_iter = latency_max_iter - self.latency_num_warmup = latency_num_warmup - self.latency_log_interval = latency_log_interval - self.latency_repeat_num = latency_repeat_num + if not isinstance(units, dict): + raise TypeError('units for estimator should be a dict', + f'but got `{type(units)}`') + for unit_key in units: + if unit_key not in ['flops', 'params', 'latency']: + raise KeyError(f'Got invalid key `{unit_key}` in units. ', + 'Should be flops, params or latency.') + if flops_params_cfg: + self.flops_params_cfg = flops_params_cfg + else: + self.flops_params_cfg = dict() + self.latency_cfg = latency_cfg if latency_cfg else dict() - def estimate(self, model: torch.nn.Module, - **kwargs) -> Dict[str, Union[float, str]]: + def estimate(self, + model: torch.nn.Module, + flops_params_cfg: dict = None, + latency_cfg: dict = None) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. + This method will first parse the merged :attr:`self.flops_params_cfg` + and the :attr:`self.latency_cfg` to check whether the keys are valid. + Args: model: The measured model. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. + + NOTE: If the `flops_params_cfg` and `latency_cfg` are both None, + this method will only estimate FLOPs/params with default settings. Returns: - Dict[str, str]): A dict that containing resource results(flops, - params and latency). + Dict[str, Union[float, str]]): A dict that contains the resource + results(FLOPs, params and latency). """ - latency_cfg = dict() resource_metrics = dict() - for key in self.__init__.__code__.co_varnames[1:]: # type: ignore - kwargs.setdefault(key, getattr(self, key)) - if 'latency' in key: - latency_cfg[key] = kwargs.pop(key) - latency_cfg['unit'] = kwargs['units'].get('latency') - latency_cfg['as_strings'] = kwargs['as_strings'] - latency_cfg['input_shape'] = kwargs['input_shape'] + measure_latency = True if latency_cfg else False - model.eval() - flops, params = get_model_complexity_info(model, **kwargs) + if flops_params_cfg: + flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg} + self._check_flops_params_cfg(flops_params_cfg) + flops_params_cfg = self._set_default_resource_params( + flops_params_cfg) + else: + flops_params_cfg = self.flops_params_cfg - if latency_cfg['measure_latency']: - latency = repeat_measure_inference_speed(model, **latency_cfg) + if latency_cfg: + latency_cfg = {**self.latency_cfg, **latency_cfg} + self._check_latency_cfg(latency_cfg) + latency_cfg = self._set_default_resource_params(latency_cfg) + else: + latency_cfg = self.latency_cfg + + model.eval() + flops, params = get_model_flops_params(model, **flops_params_cfg) + if measure_latency: + latency = get_model_latency(model, **latency_cfg) else: - latency = '0.0 ms' if kwargs['as_strings'] else 0.0 # type: ignore + latency = '0.0 ms' if self.as_strings else 0.0 # type: ignore resource_metrics.update({ 'flops': flops, @@ -139,28 +159,70 @@ def estimate(self, model: torch.nn.Module, }) return resource_metrics - def estimate_separation_modules(self, model: torch.nn.Module, - **kwargs) -> Dict[str, Union[float, str]]: - """Estimate the resources(flops/params/latency) of the spec modules. + def estimate_separation_modules( + self, + model: torch.nn.Module, + flops_params_cfg: dict = None) -> Dict[str, Union[float, str]]: + """Estimate FLOPs and params of the spec modules with separate return. Args: model: The measured model. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. Returns: - Dict[str, float]): A dict that containing resource results(flops, - params) of each modules in kwargs['spec_modules']. + Dict[str, Union[float, str]]): A dict that contains the FLOPs and + params results (string | float format) of each modules in the + ``flops_params_cfg['spec_modules']``. """ - for key in self.__init__.__code__.co_varnames[1:]: # type: ignore - kwargs.setdefault(key, getattr(self, key)) - # TODO: support speed estimation for separation modules. - if 'latency' in key: - kwargs.pop(key) + if flops_params_cfg: + flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg} + self._check_flops_params_cfg(flops_params_cfg) + flops_params_cfg = self._set_default_resource_params( + flops_params_cfg) + else: + flops_params_cfg = self.flops_params_cfg + flops_params_cfg['seperate_return'] = True - assert len(kwargs['spec_modules']), ( - f'spec_modules can not be empty when calling ' - f'{self.__class__.__name__}.estimate_separation_modules().') - kwargs['seperate_return'] = True + assert len(flops_params_cfg['spec_modules']), ( + 'spec_modules can not be empty when calling ' + f'`estimate_separation_modules` of {self.__class__.__name__} ') model.eval() - spec_modules_resources = get_model_complexity_info(model, **kwargs) + spec_modules_resources = get_model_flops_params( + model, **flops_params_cfg) return spec_modules_resources + + def _check_flops_params_cfg(self, flops_params_cfg: dict) -> None: + """Check the legality of ``flops_params_cfg``. + + Args: + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + """ + for key in flops_params_cfg: + if key not in get_model_flops_params.__code__.co_varnames[ + 1:]: # type: ignore + raise KeyError(f'Got invalid key `{key}` in flops_params_cfg.') + + def _check_latency_cfg(self, latency_cfg: dict) -> None: + """Check the legality of ``latency_cfg``. + + Args: + latency_cfg (dict): Cfg for estimating latency. + """ + for key in latency_cfg: + if key not in get_model_latency.__code__.co_varnames[ + 1:]: # type: ignore + raise KeyError(f'Got invalid key `{key}` in latency_cfg.') + + def _set_default_resource_params(self, cfg: dict) -> dict: + """Set default attributes for the input cfgs. + + Args: + cfg (dict): flops_params_cfg or latency_cfg. + """ + default_common_settings = ['input_shape', 'units', 'as_strings'] + for key in default_common_settings: + if key not in cfg: + cfg[key] = getattr(self, key) + return cfg diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py index ee6ac6bd0..60bcef4ba 100644 --- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -118,8 +118,9 @@ def sample_choice(self, model: Module) -> None: def test_estimate(self) -> None: fool_conv2d = FoolConv2d() + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) results = estimator.estimate( - model=fool_conv2d, input_shape=(1, 3, 224, 224)) + model=fool_conv2d, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] @@ -128,8 +129,9 @@ def test_estimate(self) -> None: def test_register_module(self) -> None: fool_add_constant = FoolConvModule() + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) results = estimator.estimate( - model=fool_add_constant, input_shape=(1, 3, 224, 224)) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] @@ -138,10 +140,11 @@ def test_register_module(self) -> None: def test_disable_sepc_counter(self) -> None: fool_add_constant = FoolConvModule() - rest_results = estimator.estimate( - model=fool_add_constant, + flops_params_cfg = dict( input_shape=(1, 3, 224, 224), disabled_counters=['FoolAddConstantCounter']) + rest_results = estimator.estimate( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) rest_flops_count = rest_results['flops'] rest_params_count = rest_results['params'] @@ -150,10 +153,11 @@ def test_disable_sepc_counter(self) -> None: def test_estimate_spec_module(self) -> None: fool_add_constant = FoolConvModule() - results = estimator.estimate( - model=fool_add_constant, + flops_params_cfg = dict( input_shape=(1, 3, 224, 224), spec_modules=['add_constant', 'conv2d']) + results = estimator.estimate( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] @@ -162,39 +166,39 @@ def test_estimate_spec_module(self) -> None: def test_estimate_separation_modules(self) -> None: fool_add_constant = FoolConvModule() + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=['add_constant']) results = estimator.estimate_separation_modules( - model=fool_add_constant, - input_shape=(1, 3, 224, 224), - spec_modules=['add_constant']) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) self.assertGreater(results['add_constant']['flops'], 0) with pytest.raises(AssertionError): + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=['backbone']) results = estimator.estimate_separation_modules( - model=fool_add_constant, - input_shape=(1, 3, 224, 224), - spec_modules=['backbone']) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) with pytest.raises(AssertionError): + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=[]) results = estimator.estimate_separation_modules( - model=fool_add_constant, - input_shape=(1, 3, 224, 224), - spec_modules=[]) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) def test_estimate_subnet(self) -> None: - input_shape = (1, 3, 224, 224) + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) model = MODELS.build(BACKBONE_CFG) self.sample_choice(model) copied_model = copy.deepcopy(model) results = estimator.estimate( - model=copied_model, input_shape=input_shape) + model=copied_model, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] fix_subnet = export_fix_subnet(model) load_fix_subnet(copied_model, fix_subnet) subnet_results = estimator.estimate( - model=copied_model, input_shape=input_shape) + model=copied_model, flops_params_cfg=flops_params_cfg) subnet_flops_count = subnet_results['flops'] subnet_params_count = subnet_results['params'] @@ -203,8 +207,8 @@ def test_estimate_subnet(self) -> None: # test whether subnet estimate will affect original model copied_model = copy.deepcopy(model) - results_after_estimate = \ - estimator.estimate(model=copied_model, input_shape=input_shape) + results_after_estimate = estimator.estimate( + model=copied_model, flops_params_cfg=flops_params_cfg) flops_count_after_estimate = results_after_estimate['flops'] params_count_after_estimate = results_after_estimate['params'] From 2992afb2b7764831718c473b4421be1a967807d7 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Tue, 13 Sep 2022 16:23:05 +0800 Subject: [PATCH 3/7] fix docstrings --- .../counters/flops_params_counter.py | 2 ++ .../estimators/counters/latency_counter.py | 13 +++++++----- .../estimators/resource_estimator.py | 21 ++++--------------- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index ac37e1c39..f31208248 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -54,6 +54,8 @@ def get_model_flops_params(model, Default to dict(flops='M', params='M'). as_strings (bool): Output FLOPs and params counts in a string form. Default to True. + seperate_return (bool): Whether to return the resource information + separately. Default to False. input_constructor (None | callable): If specified, it takes a callable method that generates input. otherwise, it will generate a random tensor with input shape to calculate FLOPs. Default to None. diff --git a/mmrazor/models/task_modules/estimators/counters/latency_counter.py b/mmrazor/models/task_modules/estimators/counters/latency_counter.py index c86b4c849..a4241e313 100644 --- a/mmrazor/models/task_modules/estimators/counters/latency_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/latency_counter.py @@ -21,6 +21,9 @@ def get_model_latency(model: torch.nn.Module, model (torch.nn.Module): The measured model. input_shape (tuple): Input shape (including batchsize) used for calculation. Default to (1, 3, 224, 224). + unit (str): Unit of latency in string format. Default to 'ms'. + as_strings (bool): Output latency counts in a string form. + Default to False. max_iter (Optional[int]): Max iteration num for the measurement. Default to 100. num_warmup (Optional[int]): Iteration num for warm-up stage. @@ -75,12 +78,12 @@ def _get_model_latency(model: torch.nn.Module, model (torch.nn.Module): The measured model. input_shape (tuple): Input shape (including batchsize) used for calculation. Default to (1, 3, 224, 224). - latency_max_iter (Optional[int]): Max iteration num for the - measurement. Default to 100. - latency_num_warmup (Optional[int]): Iteration num for warm-up stage. + max_iter (Optional[int]): Max iteration num for the measurement. + Default to 100. + num_warmup (Optional[int]): Iteration num for warm-up stage. Default to 5. - latency_log_interval (Optional[int]): Interval num for logging the - results. Default to 100. + log_interval (Optional[int]): Interval num for logging the results. + Default to 100. Returns: fps (float): The measured inference speed of the model. diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py index 19eddd428..ac5292d0c 100644 --- a/mmrazor/models/task_modules/estimators/resource_estimator.py +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -19,22 +19,9 @@ class ResourceEstimator(BaseEstimator): Default to dict(flops='M', params='M', latency='ms'). as_strings (bool): Output FLOPs/params/latency counts in a string form. Default to False. - spec_modules (list): List of spec modules that needed to count. - e.g., ['backbone', 'head'], ['backbone.layer1']. Default to []. - disabled_counters (list): List of disabled spec op counters. - It contains the op counter names in estimator.op_counters that - are required to be disabled, e.g., ['BatchNorm2dCounter']. - Defaults to []. - measure_latency (bool): whether to measure inference speed or not. - Default to False. - latency_max_iter (Optional[int]): Max iteration num for the - measurement. Default to 100. - latency_num_warmup (Optional[int]): Iteration num for warm-up stage. - Default to 5. - latency_log_interval (Optional[int]): Interval num for logging the - results. Default to 100. - latency_repeat_num (Optional[int]): Num of times to repeat the - measurement. Default to 1. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. Examples: >>> # direct calculate resource consume of nn.Conv2d @@ -98,7 +85,7 @@ def __init__( for unit_key in units: if unit_key not in ['flops', 'params', 'latency']: raise KeyError(f'Got invalid key `{unit_key}` in units. ', - 'Should be flops, params or latency.') + 'Should be `flops`, `params` or `latency`.') if flops_params_cfg: self.flops_params_cfg = flops_params_cfg else: From e67e611f2b6bddeda23230dae7d1036565d0cb3a Mon Sep 17 00:00:00 2001 From: humu789 Date: Tue, 13 Sep 2022 17:30:07 +0800 Subject: [PATCH 4/7] update search loop and config --- .../spos/spos_mobilenet_search_8xb128_in1k.py | 2 +- .../spos_shufflenet_search_8xb128_in1k.py | 2 +- .../detnas_frcnn_shufflenet_search_coco_1x.py | 4 +- .../engine/runner/evolution_search_loop.py | 48 ++++++++----------- mmrazor/engine/runner/subnet_sampler_loop.py | 41 +++++++--------- mmrazor/engine/runner/utils/__init__.py | 3 +- mmrazor/engine/runner/utils/check.py | 39 +++++++++++++++ 7 files changed, 81 insertions(+), 58 deletions(-) create mode 100644 mmrazor/engine/runner/utils/check.py diff --git a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py index f670e5c0c..4f5edb316 100644 --- a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 465 * 1e6), + flops_range=(0., 465.), score_key='accuracy/top1') diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py index 6f8dc9366..f3f963e40 100644 --- a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 330 * 1e6), + flops_range=(0., 330.), score_key='accuracy/top1') diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py index 0bd3b71fe..d1dd1637a 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py @@ -13,5 +13,5 @@ num_mutation=20, num_crossover=20, mutate_prob=0.1, - flops_range=None, - score_key='bbox_mAP') + flops_range=(0., 300.), + score_key='coco/bbox_mAP') diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 263fb76e1..008d6d018 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import os import os.path as osp import random @@ -16,9 +15,9 @@ from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet +from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet -from .utils import crossover +from .utils import crossover, check_subnet_flops @LOOPS.register_module() @@ -42,12 +41,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): num_crossover (int): The number of candidates got by crossover. Defaults to 25. mutate_prob (float): The probability of mutation. Defaults to 0.1. - flops_range (tuple, optional): flops_range to be used for screening - candidates. - resource_input_shape (Tuple): Input shape when measuring flops. - Default to (1, 3, 224, 224). - resource_spec_modules (list): Used for specify modules need to counter. - Defaults to list(). + flops_range (tuple, optional): It is used for screening candidates. + resource_estimator_cfg (dict): The config for building estimator, which + is be used to estimate the flops of sampled subnet. Defaults to + None, which means default config is used. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. init_candidates (str, optional): The candidates file path, which is @@ -67,9 +64,8 @@ def __init__(self, num_mutation: int = 25, num_crossover: int = 25, mutate_prob: float = 0.1, - flops_range: Optional[Tuple[float, float]] = (0., 330), - resource_input_shape: Tuple = (1, 3, 224, 224), - resource_spec_modules: List = [], + flops_range: Optional[Tuple[float, float]] = (0., 330.), + resource_estimator_cfg: Optional[dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -103,9 +99,10 @@ def __init__(self, correct init candidates file' self.top_k_candidates = Candidates() - self.estimator = ResourceEstimator( - input_shape=resource_input_shape, - flops_params_cfg=dict(spec_modules=resource_spec_modules)) + if resource_estimator_cfg is None: + self.estimator = ResourceEstimator() + else: + self.estimator = ResourceEstimator(**resource_estimator_cfg) if self.runner.distributed: self.model = runner.model.module @@ -304,17 +301,10 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: Returns: bool: The result of checking. """ - if self.flops_range is None: - return True - - self.model.set_subnet(random_subnet) - fix_mutable = export_fix_subnet(self.model) - copied_model = copy.deepcopy(self.model) - load_fix_subnet(copied_model, fix_mutable) - results = self.estimator.estimate(copied_model) - flops = results['flops'] - - if self.flops_range[0] <= flops <= self.flops_range[1]: # type: ignore - return True - else: - return False + is_pass = check_subnet_flops( + model=self.model, + subnet=random_subnet, + estimator=self.estimator, + flops_range=self.flops_range) + + return is_pass diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index 9beeaf2b4..ebf07ced2 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -15,8 +15,9 @@ from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet +from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet +from .utils import check_subnet_flops class BaseSamplerTrainLoop(IterBasedTrainLoop): @@ -103,10 +104,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. flops_range (dict): Constraints to be used for screening candidates. - resource_input_shape (Tuple): Input shape when measuring flops. - Default to (1, 3, 224, 224). - resource_spec_modules (list): Used for specify modules need to counter. - Defaults to list(). + resource_estimator_cfg (dict): The config for building estimator, which + is be used to estimate the flops of sampled subnet. Defaults to + None, which means default config is used. num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. num_samples (int): The number of sample in each sampling subnet. @@ -141,8 +141,7 @@ def __init__(self, val_interval: int = 1000, score_key: str = 'accuracy/top1', flops_range: Optional[Tuple[float, float]] = (0., 330), - resource_input_shape: Tuple = (1, 3, 224, 224), - resource_spec_modules: List = [], + resource_estimator_cfg: Optional[dict] = None, num_candidates: int = 1000, num_samples: int = 10, top_k: int = 5, @@ -179,9 +178,10 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() - self.estimator = ResourceEstimator( - input_shape=resource_input_shape, - flops_params_cfg=dict(spec_modules=resource_spec_modules)) + if resource_estimator_cfg is None: + self.estimator = ResourceEstimator() + else: + self.estimator = ResourceEstimator(**resource_estimator_cfg) def run(self) -> None: """Launch training.""" @@ -322,20 +322,13 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: Returns: bool: The result of checking. """ - if self.flops_range is None: - return True - - self.model.set_subnet(random_subnet) - fix_mutable = export_fix_subnet(self.model) - copied_model = copy.deepcopy(self.model) - load_fix_subnet(copied_model, fix_mutable) - results = self.estimator.estimate(copied_model) - flops = results['flops'] - - if self.flops_range[0] <= flops <= self.flops_range[1]: # type: ignore - return True - else: - return False + is_pass = check_subnet_flops( + model=self.model, + subnet=random_subnet, + estimator=self.estimator, + flops_range=self.flops_range) + + return is_pass def _save_candidates(self) -> None: """Save the candidates to init the next searching.""" diff --git a/mmrazor/engine/runner/utils/__init__.py b/mmrazor/engine/runner/utils/__init__.py index 7aaf29539..2c1fdc5fa 100644 --- a/mmrazor/engine/runner/utils/__init__.py +++ b/mmrazor/engine/runner/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .genetic import crossover +from .check import check_subnet_flops -__all__ = ['crossover'] +__all__ = ['crossover', 'check_subnet_flops'] diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py new file mode 100644 index 000000000..2fcf7cdf8 --- /dev/null +++ b/mmrazor/engine/runner/utils/check.py @@ -0,0 +1,39 @@ +import copy +import torch.nn as nn +from typing import Tuple, Optional +from mmdet.models.detectors import BaseDetector +from mmrazor.structures import export_fix_subnet, load_fix_subnet +from mmrazor.utils import SupportRandomSubnet + +def check_subnet_flops( + model: nn.Module, + subnet: SupportRandomSubnet, + estimator: callable, + flops_range: Optional[Tuple[float, float]]=None + ) -> bool: + """Check whether is beyond flops constraints. + + Returns: + bool: The result of checking. + """ + if flops_range is None: + return True + + assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture') + model.set_subnet(subnet) + fix_mutable = export_fix_subnet(model) + copied_model = copy.deepcopy(model) + load_fix_subnet(copied_model, fix_mutable) + + model_to_check = model.architecture + if isinstance(model_to_check, BaseDetector): + results = estimator.estimate(model_to_check.backbone) + else: + results = estimator.estimate(model_to_check) + + flops = results['flops'] + flops_mix, flops_max = flops_range + if flops_mix <= flops <= flops_max: # type: ignore + return True + else: + return False \ No newline at end of file From 3e6146c64811f4d169120bf48a79b291b998a4fa Mon Sep 17 00:00:00 2001 From: humu789 Date: Tue, 13 Sep 2022 17:51:16 +0800 Subject: [PATCH 5/7] fix lint --- .../engine/runner/evolution_search_loop.py | 6 ++--- mmrazor/engine/runner/subnet_sampler_loop.py | 7 +++-- mmrazor/engine/runner/utils/__init__.py | 2 +- mmrazor/engine/runner/utils/check.py | 26 +++++++++++-------- 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 008d6d018..a9a76b383 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -17,7 +17,7 @@ from mmrazor.registry import LOOPS from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet -from .utils import crossover, check_subnet_flops +from .utils import check_subnet_flops, crossover @LOOPS.register_module() @@ -43,7 +43,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): mutate_prob (float): The probability of mutation. Defaults to 0.1. flops_range (tuple, optional): It is used for screening candidates. resource_estimator_cfg (dict): The config for building estimator, which - is be used to estimate the flops of sampled subnet. Defaults to + is be used to estimate the flops of sampled subnet. Defaults to None, which means default config is used. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. @@ -306,5 +306,5 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: subnet=random_subnet, estimator=self.estimator, flops_range=self.flops_range) - + return is_pass diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index ebf07ced2..1127aab21 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import math import os import random @@ -15,7 +14,7 @@ from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet +from mmrazor.structures import Candidates from mmrazor.utils import SupportRandomSubnet from .utils import check_subnet_flops @@ -105,7 +104,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): candidates. Defaults to 'accuracy_top-1'. flops_range (dict): Constraints to be used for screening candidates. resource_estimator_cfg (dict): The config for building estimator, which - is be used to estimate the flops of sampled subnet. Defaults to + is be used to estimate the flops of sampled subnet. Defaults to None, which means default config is used. num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. @@ -327,7 +326,7 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: subnet=random_subnet, estimator=self.estimator, flops_range=self.flops_range) - + return is_pass def _save_candidates(self) -> None: diff --git a/mmrazor/engine/runner/utils/__init__.py b/mmrazor/engine/runner/utils/__init__.py index 2c1fdc5fa..ec2f2cb29 100644 --- a/mmrazor/engine/runner/utils/__init__.py +++ b/mmrazor/engine/runner/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .genetic import crossover from .check import check_subnet_flops +from .genetic import crossover __all__ = ['crossover', 'check_subnet_flops'] diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index 2fcf7cdf8..dee81e639 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -1,16 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. import copy +from typing import Optional, Tuple + import torch.nn as nn -from typing import Tuple, Optional -from mmdet.models.detectors import BaseDetector +from mmdet.models.detectors import BaseDetector + +from mmrazor.models import ResourceEstimator from mmrazor.structures import export_fix_subnet, load_fix_subnet from mmrazor.utils import SupportRandomSubnet + def check_subnet_flops( - model: nn.Module, - subnet: SupportRandomSubnet, - estimator: callable, - flops_range: Optional[Tuple[float, float]]=None - ) -> bool: + model: nn.Module, + subnet: SupportRandomSubnet, + estimator: ResourceEstimator, + flops_range: Optional[Tuple[float, float]] = None) -> bool: """Check whether is beyond flops constraints. Returns: @@ -27,13 +31,13 @@ def check_subnet_flops( model_to_check = model.architecture if isinstance(model_to_check, BaseDetector): - results = estimator.estimate(model_to_check.backbone) + results = estimator.estimate(model=model_to_check.backbone) else: - results = estimator.estimate(model_to_check) - + results = estimator.estimate(model=model_to_check) + flops = results['flops'] flops_mix, flops_max = flops_range if flops_mix <= flops <= flops_max: # type: ignore return True else: - return False \ No newline at end of file + return False From 6b16de4d13ec1e53747926655475f5c142eb1ad1 Mon Sep 17 00:00:00 2001 From: humu789 Date: Tue, 13 Sep 2022 18:37:39 +0800 Subject: [PATCH 6/7] update unittest --- .../test_evolution_search_loop.py | 5 ++- .../test_runners/test_subnet_sampler_loop.py | 22 +++---------- tests/test_runners/test_utils/test_check.py | 31 +++++++++++++++++++ 3 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 tests/test_runners/test_utils/test_check.py diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py index 8655c9c4a..f30019274 100644 --- a/tests/test_runners/test_evolution_search_loop.py +++ b/tests/test_runners/test_evolution_search_loop.py @@ -112,8 +112,7 @@ def test_init(self): self.assertEqual(loop.candidates, fake_candidates) @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet') - @patch('mmrazor.models.task_modules.ResourceEstimator.estimate') - def test_run_epoch(self, mock_flops, mock_export_fix_subnet): + def test_run_epoch(self, mock_export_fix_subnet): # test_run_epoch: distributed == False loop_cfg = copy.deepcopy(self.train_cfg) loop_cfg.runner = self.runner @@ -153,7 +152,7 @@ def test_run_epoch(self, mock_flops, mock_export_fix_subnet): self.runner.work_dir = self.temp_dir fake_subnet = {'1': 'choice1', '2': 'choice2'} loop.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = dict(flops=10.0, params=2.0) + loop._check_constraints = MagicMock(return_value=True) mock_export_fix_subnet.return_value = fake_subnet loop.run_epoch() self.assertEqual(len(loop.candidates), 4) diff --git a/tests/test_runners/test_subnet_sampler_loop.py b/tests/test_runners/test_subnet_sampler_loop.py index 1271b864b..fca29b823 100644 --- a/tests/test_runners/test_subnet_sampler_loop.py +++ b/tests/test_runners/test_subnet_sampler_loop.py @@ -192,29 +192,15 @@ def test_sample_subnet(self): self.assertEqual(subnet, fake_subnet) self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1) - @patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet') - @patch('mmrazor.models.task_modules.ResourceEstimator.estimate') - def test_run(self, mock_flops, mock_export_fix_subnet): - # test run with flops_range=None - cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_run1' - runner = Runner.from_cfg(cfg) - fake_subnet = {'1': 'choice1', '2': 'choice2'} - runner.model.sample_subnet = MagicMock(return_value=fake_subnet) - runner.train() - - self.assertEqual(runner.iter, runner.max_iters) - assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl')) - + def test_run(self): # test run with _check_constraints cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_run2' - cfg.train_cfg.flops_range = (0, 100) + cfg.experiment_name = 'test_run1' runner = Runner.from_cfg(cfg) fake_subnet = {'1': 'choice1', '2': 'choice2'} runner.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = dict(flops=10.0, params=2.0) - mock_export_fix_subnet.return_value = fake_subnet + loop = runner.build_train_loop(cfg.train_cfg) + loop._check_constraints = MagicMock(return_value=True) runner.train() self.assertEqual(runner.iter, runner.max_iters) diff --git a/tests/test_runners/test_utils/test_check.py b/tests/test_runners/test_utils/test_check.py new file mode 100644 index 000000000..31808fedf --- /dev/null +++ b/tests/test_runners/test_utils/test_check.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.engine.runner.utils import check_subnet_flops +from mmdet.models.detectors import BaseDetector +from unittest.mock import patch + +@patch('mmrazor.models.ResourceEstimator') +@patch('mmrazor.models.SPOS') +def test_check_subnet_flops(mock_model, mock_estimator): + # flops_range = None + flops_range = None + fake_subnet = {'1': 'choice1', '2': 'choice2'} + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, flops_range) + assert result is True + + # flops_range is not None + # architecturte is BaseDetector + flops_range = (0., 100.) + mock_model.architecture = BaseDetector + fake_results = {'flops': 50.} + mock_estimator.estimate.return_value = fake_results + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, flops_range) + assert result is True + + # flops_range is not None + # architecturte is BaseDetector + flops_range = (0., 100.) + fake_results = {'flops': -50.} + mock_estimator.estimate.return_value = fake_results + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, flops_range) + assert result is False + From 1fafeb12f005dac1c2470a4b76199d52a1a684eb Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Wed, 14 Sep 2022 20:14:39 +0800 Subject: [PATCH 7/7] decouple mmdet dependency and fix lint --- mmrazor/engine/runner/utils/check.py | 7 ++++++- .../single_stage_detector_loss_calculator.py | 7 ++++++- tests/test_runners/test_utils/test_check.py | 21 +++++++++++++------ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index dee81e639..e2fdcfcc6 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -3,12 +3,17 @@ from typing import Optional, Tuple import torch.nn as nn -from mmdet.models.detectors import BaseDetector from mmrazor.models import ResourceEstimator from mmrazor.structures import export_fix_subnet, load_fix_subnet from mmrazor.utils import SupportRandomSubnet +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + def check_subnet_flops( model: nn.Module, diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py index 85f25eaff..5365831b9 100644 --- a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py +++ b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py @@ -1,9 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmdet.models import BaseDetector from mmrazor.registry import TASK_UTILS +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + # todo: adapt to mmdet 2.0 @TASK_UTILS.register_module() diff --git a/tests/test_runners/test_utils/test_check.py b/tests/test_runners/test_utils/test_check.py index 31808fedf..b9bd57989 100644 --- a/tests/test_runners/test_utils/test_check.py +++ b/tests/test_runners/test_utils/test_check.py @@ -1,15 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmrazor.engine.runner.utils import check_subnet_flops -from mmdet.models.detectors import BaseDetector from unittest.mock import patch +from mmrazor.engine.runner.utils import check_subnet_flops + +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + + @patch('mmrazor.models.ResourceEstimator') @patch('mmrazor.models.SPOS') def test_check_subnet_flops(mock_model, mock_estimator): # flops_range = None flops_range = None fake_subnet = {'1': 'choice1', '2': 'choice2'} - result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, flops_range) + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) assert result is True # flops_range is not None @@ -18,7 +26,8 @@ def test_check_subnet_flops(mock_model, mock_estimator): mock_model.architecture = BaseDetector fake_results = {'flops': 50.} mock_estimator.estimate.return_value = fake_results - result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, flops_range) + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) assert result is True # flops_range is not None @@ -26,6 +35,6 @@ def test_check_subnet_flops(mock_model, mock_estimator): flops_range = (0., 100.) fake_results = {'flops': -50.} mock_estimator.estimate.return_value = fake_results - result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, flops_range) + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) assert result is False -