diff --git a/docs/en_US/Compression/Quantizer.rst b/docs/en_US/Compression/Quantizer.rst
index 073d0cd961..8b5ffd99f9 100644
--- a/docs/en_US/Compression/Quantizer.rst
+++ b/docs/en_US/Compression/Quantizer.rst
@@ -82,10 +82,25 @@ configuration needed by this algorithm :
   disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
   state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
 
-note
-^^^^
+Batch normalization folding
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-batch normalization folding is currently not supported.
+Batch normalization folding is supported by the QAT quantizer. It can be enabled by passing an argument ``dummy_input`` to
+the quantizer, for example:
+
+.. code-block:: python
+
+    # assume your model takes an input of shape (1, 1, 28, 28)
+    # and dummy_input must be on the same device as the model
+    dummy_input = torch.randn(1, 1, 28, 28)
+
+    # pass the dummy_input to the quantizer
+    quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input)
+
+
+The quantizer will automatically detect Conv-BN patterns and simulate the batch normalization folding process in the training
+graph. Note that when the quantization-aware training process is finished, the folded weight/bias will be restored after calling
+``quantizer.export_model``.
 
 ----
 
diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py
index dca1ef778e..09e53329b0 100644
--- a/nni/algorithms/compression/pytorch/quantization/quantizers.py
+++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -6,7 +6,7 @@
 import torch
 from schema import Schema, And, Or, Optional
 from nni.compression.pytorch.utils.config_validation import QuantizerSchema
-from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
+from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType
 
 __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
 
@@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
         """
         Parameters
         ----------
@@ -145,8 +145,13 @@ def __init__(self, model, config_list, optimizer=None):
                     state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
                 - op_types : list of string
                     types of nn.module you want to apply quantization, eg. 'Conv2d'
+        - dummy_input : tuple of tensor
+            inputs to the model, which are used to trace the graph of the module. The graph is used to find
+            Conv-BN patterns, for which batch normalization folding is then enabled. If dummy_input is not
+            given, batch normalization folding is disabled.
""" - super().__init__(model, config_list, optimizer) + + super().__init__(model, config_list, optimizer, dummy_input) self.quant_grad = QATGrad.apply modules_to_compress = self.get_modules_to_compress() device = next(model.parameters()).device @@ -169,8 +174,9 @@ def _del_simulated_attr(self, module): """ delete redundant parameters in quantize module """ - del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \ - 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit'] + del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', + 'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', + 'activation_bit', 'BN_FOLD_TAG'] for attr in del_attr_list: if hasattr(module, attr): delattr(module, attr) @@ -334,6 +340,23 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config[name]['weight_bit'] = int(module.weight_bit) calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input) calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input) + + # Recover weight/bias for batch normalization folding + if hasattr(module, BN_FOLD_TAG): + actual_weight = getattr(module, 'old_weight', None) + if actual_weight is None: + logger.warning("Can not recover weight for layer %s. " + "This may lead to a wrong accuracy performance on the backend.", name) + delattr(module, 'weight') + module.register_parameter('weight', actual_weight) + + actual_bias = getattr(module, 'old_bias', None) + delattr(module, 'bias') + if actual_bias is not None: + module.register_parameter('bias', actual_bias) + else: + setattr(module, 'bias', None) + if hasattr(module, 'activation_bit'): calibration_config[name]['activation_bit'] = int(module.activation_bit) calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation) @@ -344,9 +367,39 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ return calibration_config - def fold_bn(self, config, **kwargs): - # TODO simulate folded weight - pass + def fold_bn(self, *inputs, wrapper): + """ + Simulate batch normalization folding in the training graph. Folded weight and bias are + returned for the following operations. 
+
+        Parameters
+        ----------
+        inputs : tuple of torch.Tensor
+            inputs for the module
+        wrapper : QuantizerModuleWrapper
+            the wrapper of the original module
+
+        Returns
+        -------
+        Tuple of torch.Tensor
+        """
+        module = wrapper.module
+        bn_module = wrapper.bn_module
+        with torch.no_grad():
+            output = module(*inputs)
+            _ = bn_module(output)
+        running_mean = bn_module.running_mean
+        running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
+        bn_weight = bn_module.weight
+        bn_bias = bn_module.bias
+        dimensions = len(module.weight.shape)
+        shape = [-1] + [1] * (dimensions - 1)
+        new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
+        if hasattr(module, 'old_bias'):
+            new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
+        else:
+            new_bias = bn_bias - running_mean / running_var * bn_weight
+        return new_weight, new_bias
 
     def step_with_optimizer(self):
         """
diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py
index 01b8bb24e4..4dcde45422 100644
--- a/nni/compression/pytorch/compressor.py
+++ b/nni/compression/pytorch/compressor.py
@@ -4,6 +4,7 @@
 import types
 import logging
 import torch
+from nni.common.graph_utils import build_module_graph
 from . import default_layers
 
 _logger = logging.getLogger(__name__)
@@ -463,7 +464,7 @@ def get_pruned_weights(self, dim=0):
 
 
 class QuantizerModuleWrapper(torch.nn.Module):
-    def __init__(self, module, module_name, module_type, config, quantizer):
+    def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
         """
         Wrap an module to enable data parallel, forward method customization and buffer
         registeration.
@@ -479,6 +480,8 @@ def __init__(self, module, module_name, module_type, config, quantizer):
             the type of the module to compress
         quantizer :quantizer
             the quantizer used to calculate mask
+        bn_module : torch.nn.Module
+            batch normalization layer corresponding to the current module, used to simulate batch normalization folding
         """
         super().__init__()
         # origin layer information
@@ -488,6 +491,7 @@ def __init__(self, module, module_name, module_type, config, quantizer):
         # config and pruner
         self.config = config
         self.quantizer = quantizer
+        self.bn_module = bn_module
 
         # register buffer and parameter
         # old_weight is used to store origin weight and weight is used to store quantized weight
@@ -501,6 +505,17 @@ def __init__(self, module, module_name, module_type, config, quantizer):
                 delattr(self.module, 'weight')
                 self.module.register_buffer('weight', self.module.old_weight)
 
+                # for batch normalization folding
+                if self.bn_module is not None:
+                    if _check_bias(self.module):
+                        self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
+                        init_tensor = self.module.old_bias
+                    else:
+                        init_tensor = torch.zeros_like(self.bn_module.weight)
+                    delattr(self.module, 'bias')
+                    self.module.register_buffer('bias', init_tensor)
+                    setattr(module, BN_FOLD_TAG, True)
+
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
             inputs = self.quantizer.quant_grad(
                 inputs,
                 QuantType.QUANT_INPUT,
                 self)
 
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            if self.bn_module is not None:
+                # simulate batch normalization folding
+                new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
+                self.module.bias = new_bias
+                self.module.weight = new_weight
+            else:
+                new_weight = self.module.old_weight
+
             self.quantizer.quant_grad(
-                self.module.old_weight,
+                new_weight,
                 QuantType.QUANT_WEIGHT,
                 self, inputs[0])
-            result = self.module(*inputs)
-        else:
-            result = self.module(*inputs)
+
+        result = self.module(*inputs)
 
         if 'output' in self.config['quant_types']:
             result = self.quantizer.quant_grad(
@@ -525,12 +547,35 @@ def forward(self, *inputs):
         return result
 
 
+class QuantizerIdentityWrapper(torch.nn.Module):
+    def __init__(self, module, module_name):
+        """
+        Used to wrap modules that should be treated as torch.Identity
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module to be wrapped
+        module_name : str
+            the name of the module to be wrapped; the wrapper module shares the same name
+        """
+        super().__init__()
+        self.module = module
+        self.module_name = module_name
+
+    def forward(self, x):
+        return x
+
+
 class Quantizer(Compressor):
     """
     Base quantizer for pytorch quantizer
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+        self.identity_wrappers = []
+        self.conv_bn_patterns = {}
+        self.find_conv_bn_patterns(model, dummy_input)
         super().__init__(model, config_list, optimizer)
         self.quant_grad = QuantGrad.apply
         if self.optimizer is not None:
@@ -540,6 +585,10 @@ def __init__(self, model, config_list, optimizer=None):
                     # old_weight is registered to keep track of weight before quantization
                     # and it is trainable, therefore, it should be added to optimizer.
                     self.optimizer.add_param_group({"params": wrapper.module.old_weight})
+                    # This is for a conv layer with bias followed by bn. Although this situation is relatively rare,
+                    # the old_bias still needs to be added to the optimizer when it occurs
+                    if hasattr(wrapper.module, "old_bias"):
+                        self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")})
 
     def quantize_weight(self, wrapper, **kwargs):
         """
@@ -597,7 +646,36 @@ def _wrap_modules(self, layer, config):
         for quant_type in config['quant_types']:
             assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
 
-        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        # bind the bn module to its corresponding conv module
+        bn_module = None
+        if layer.name in self.conv_bn_patterns:
+            bn_module_name = self.conv_bn_patterns[layer.name]
+            for name, module in self.bound_model.named_modules():
+                if name == bn_module_name:
+                    bn_module = module
+                    break
+            assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name)
+            self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name))
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module)
+
+    def _wrap_model(self):
+        """
+        wrap all modules that need to be compressed
+
+        """
+        # wrap folded bn in order to bypass its forward process
+        for wrapper in reversed(self.identity_wrappers):
+            _setattr(self.bound_model, wrapper.module_name, wrapper)
+        super()._wrap_model()
+
+    def _unwrap_model(self):
+        """
+        unwrap all modules that need to be compressed
+
+        """
+        for wrapper in self.identity_wrappers:
+            _setattr(self.bound_model, wrapper.module_name, wrapper.module)
+        super()._unwrap_model()
 
     def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
                           input_shape=None, device=None):
@@ -660,6 +738,30 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
         """
         raise NotImplementedError('Quantizer must overload export_model()')
 
+    def find_conv_bn_patterns(self, model, dummy_input):
+        """
+        Find all Conv-BN patterns, used for batch normalization folding
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            model to be analyzed.
+        dummy_input : tuple of torch.Tensor
+            inputs to the model, used to generate the TorchScript graph
+        """
+        if dummy_input is None:
+            _logger.debug("Model inputs are not given, batch normalization folding is disabled")
+            return
+
+        graph = build_module_graph(model, dummy_input)
+        for node_group in graph.nodes_py.nodes_op:
+            if node_group.op_type in BN_FOLD_OP:
+                successors = graph.find_successors(node_group.unique_name)
+                successors = [graph.name_to_node[x] for x in successors]
+                for successor in successors:
+                    if successor.op_type == 'BatchNorm2d':
+                        self.conv_bn_patterns[node_group.name] = successor.name
+
     def step_with_optimizer(self):
         pass
 
@@ -677,6 +779,9 @@ class QuantType:
         2: "output"
     }
 
+BN_FOLD_OP = ["Conv2d"]
+BN_FOLD_TAG = 'BN_FOLD_TAG'
+
 class QuantGrad(torch.autograd.Function):
     """
     Base class for overriding backward function of quantization operation.
@@ -773,6 +878,12 @@ def _check_weight(module):
     except AttributeError:
         return False
 
+def _check_bias(module):
+    try:
+        return isinstance(module.bias.data, torch.Tensor)
+    except AttributeError:
+        return False
+
 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
     if quant_type == QuantType.QUANT_INPUT:
         output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
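
As a worked reference for the folding arithmetic that ``fold_bn`` simulates: per output channel, the folded parameters are
``W_fold = W * gamma / sigma`` and ``b_fold = beta + (b - mu) * gamma / sigma``, where ``mu``, ``gamma`` and ``beta`` are the
batch normalization statistics and affine parameters and ``sigma`` is ``sqrt(running_var + eps)``. The sketch below is
illustrative only (the Conv-BN pair and its shapes are arbitrary, not taken from the patch); it checks that a folded
convolution reproduces the Conv-BN output:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # arbitrary Conv-BN pair for illustration; any Conv2d followed by BatchNorm2d behaves the same
    conv = nn.Conv2d(3, 8, kernel_size=3, bias=True)
    bn = nn.BatchNorm2d(8)

    with torch.no_grad():
        # give the BN non-trivial statistics and affine parameters so the folding is visible
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
    bn.eval()  # use the tracked running statistics, as the exported (folded) model would

    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        reference = bn(conv(x))

        # same per-channel formulas as fold_bn: scale the weight by gamma / sigma, shift the bias
        std = torch.sqrt(bn.running_var + bn.eps)
        shape = [-1] + [1] * (conv.weight.dim() - 1)   # (out_channels, 1, 1, 1)
        folded_weight = conv.weight * (bn.weight / std).reshape(shape)
        folded_bias = bn.bias + (conv.bias - bn.running_mean) / std * bn.weight

        folded = F.conv2d(x, folded_weight, folded_bias, stride=conv.stride, padding=conv.padding)

    print(torch.allclose(reference, folded, atol=1e-5))  # True: folding preserves the Conv-BN output

During quantization-aware training, ``fold_bn`` applies the same formulas to ``old_weight``/``old_bias`` using the running
statistics updated on the current batch, so the weights that get quantized already include the effect of the batch
normalization that is folded away at export time.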