From fbce790c593ea9e91b3ebd4cb67c6e8797440cdb Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Wed, 7 Jul 2021 10:20:28 +0800 Subject: [PATCH 1/9] Add batch normalization folding to QAT quantizer --- .../pytorch/quantization/quantizers.py | 48 ++++++- nni/compression/pytorch/compressor.py | 123 +++++++++++++++++- 2 files changed, 157 insertions(+), 14 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index dca1ef778e..14e2af24fe 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer): http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf """ - def __init__(self, model, config_list, optimizer=None): + def __init__(self, model, config_list, optimizer=None, model_inputs=None): """ Parameters ---------- @@ -145,8 +145,11 @@ def __init__(self, model, config_list, optimizer=None): state where activation quantization ranges do not exclude a significant fraction of values, default value is 0 - op_types : list of string types of nn.module you want to apply quantization, eg. 'Conv2d' + - model_inputs : tuple of tensor + inputs to the model, which are used to get the graph of the module """ - super().__init__(model, config_list, optimizer) + + super().__init__(model, config_list, optimizer, model_inputs) self.quant_grad = QATGrad.apply modules_to_compress = self.get_modules_to_compress() device = next(model.parameters()).device @@ -169,8 +172,9 @@ def _del_simulated_attr(self, module): """ delete redundant parameters in quantize module """ - del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \ - 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit'] + del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', + 'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', + 'activation_bit'] for attr in del_attr_list: if hasattr(module, attr): delattr(module, attr) @@ -344,9 +348,39 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ return calibration_config - def fold_bn(self, config, **kwargs): - # TODO simulate folded weight - pass + def fold_bn(self, *inputs, wrapper): + """ + Simulate batch normalization folding in the training graph. Folded weight and bias are + returned for the following operations. 
+ + Parameters + ---------- + inputs : tuple of torch.Tensor + inputs for the module + wrapper : QuantizerModuleWrapper + the wrapper for origin module + + Returns + ------- + Tuple of torch.Tensor + """ + module = wrapper.module + bn_module = wrapper.bn_module + with torch.no_grad(): + output = module(*inputs) + _ = bn_module(output) + running_mean = bn_module.running_mean + running_var = torch.sqrt(bn_module.running_var + 1e-10) + bn_weight = bn_module.weight + bn_bias = bn_module.bias + dimensions = len(module.weight.shape) + shape = [-1] + [1] * (dimensions - 1) + new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape) + if hasattr(module, 'old_bias'): + new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight + else: + new_bias = bn_bias - running_mean / running_var * bn_weight + return new_weight, new_bias def step_with_optimizer(self): """ diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index 01b8bb24e4..c93600b39d 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -4,6 +4,7 @@ import types import logging import torch +from nni.common.graph_utils import build_module_graph from . import default_layers _logger = logging.getLogger(__name__) @@ -463,7 +464,7 @@ def get_pruned_weights(self, dim=0): class QuantizerModuleWrapper(torch.nn.Module): - def __init__(self, module, module_name, module_type, config, quantizer): + def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None): """ Wrap an module to enable data parallel, forward method customization and buffer registeration. @@ -479,6 +480,8 @@ def __init__(self, module, module_name, module_type, config, quantizer): the type of the module to compress quantizer :quantizer the quantizer used to calculate mask + bn_module : torch.nn.Module + batch norm layer corresponding to current module, used for simulating batch normalization folding """ super().__init__() # origin layer information @@ -488,6 +491,7 @@ def __init__(self, module, module_name, module_type, config, quantizer): # config and pruner self.config = config self.quantizer = quantizer + self.bn_module = bn_module # register buffer and parameter # old_weight is used to store origin weight and weight is used to store quantized weight @@ -501,6 +505,17 @@ def __init__(self, module, module_name, module_type, config, quantizer): delattr(self.module, 'weight') self.module.register_buffer('weight', self.module.old_weight) + # for batch normalization folding + if self.bn_module is not None: + if _check_bias(self.module): + self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias)) + init_tensor = self.module.old_bias + else: + init_tensor = torch.zeros_like(self.bn_module.weight) + delattr(self.module, 'bias') + self.module.register_buffer('bias', init_tensor) + + def forward(self, *inputs): if 'input' in self.config['quant_types']: inputs = self.quantizer.quant_grad( @@ -509,13 +524,19 @@ def forward(self, *inputs): self) if 'weight' in self.config['quant_types'] and _check_weight(self.module): + if self.bn_module is not None: + # simulate batch normalization folding + new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self) + self.module.bias = new_bias + else: + new_weight = self.module.old_weight + self.quantizer.quant_grad( - self.module.old_weight, + new_weight, QuantType.QUANT_WEIGHT, self, inputs[0]) - result = self.module(*inputs) - else: - result = self.module(*inputs) + + 
result = self.module(*inputs) if 'output' in self.config['quant_types']: result = self.quantizer.quant_grad( @@ -525,12 +546,35 @@ def forward(self, *inputs): return result +class QuantizerIdentityWrapper(torch.nn.Module): + def __init__(self, module, module_name): + """ + Used to wrap modules that should be treated as torch.Identity + + Parameters + ---------- + module : pytorch module + the module to be wrapped + module_name : str + the name of the module to wrapped, wrapper module shares same name + """ + super().__init__() + self.module = module + self.module_name = module_name + + def forward(self, x): + return x + + class Quantizer(Compressor): """ Base quantizer for pytorch quantizer """ - def __init__(self, model, config_list, optimizer=None): + def __init__(self, model, config_list, optimizer=None, model_inputs=None): + self.identity_wrappers = [] + self.conv_bn_patterns = {} + self.find_conv_bn_patterns(model, model_inputs) super().__init__(model, config_list, optimizer) self.quant_grad = QuantGrad.apply if self.optimizer is not None: @@ -540,6 +584,10 @@ def __init__(self, model, config_list, optimizer=None): # old_weight is registered to keep track of weight before quantization # and it is trainable, therefore, it should be added to optimizer. self.optimizer.add_param_group({"params": wrapper.module.old_weight}) + # This is for conv with bias + bn. Although this situation is relatively rare, + # we still need to deal with the old_bias when it occurs + if hasattr(wrapper.module, "old_bias"): + self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")}) def quantize_weight(self, wrapper, **kwargs): """ @@ -597,7 +645,36 @@ def _wrap_modules(self, layer, config): for quant_type in config['quant_types']: assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type - return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self) + # bound bn module to corresponding conv module + bn_module = None + if layer.name in self.conv_bn_patterns: + bn_module_name = self.conv_bn_patterns[layer.name] + for name, module in self.bound_model.named_modules(): + if name == bn_module_name: + bn_module = module + break + assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name) + self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name)) + return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module) + + def _wrap_model(self): + """ + wrap all modules that needed to be compressed + + """ + # wrap folded bn in order to bypass its forward process + for wrapper in reversed(self.identity_wrappers): + _setattr(self.bound_model, wrapper.module_name, wrapper) + super()._wrap_model() + + def _unwrap_model(self): + """ + unwrap all modules that needed to be compressed + + """ + for wrapper in self.identity_wrappers: + _setattr(self.bound_model, wrapper.module_name, wrapper.module) + super()._unwrap_model() def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, input_shape=None, device=None): @@ -660,6 +737,30 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ """ raise NotImplementedError('Quantizer must overload export_model()') + def find_conv_bn_patterns(self, model, model_inputs): + """ + Find all Conv-BN patterns, used for batch normalization folding + + Parameters + ---------- + model : torch.nn.Module + model to be analyzed. 
+ model_inputs : tupel of torch.tensor + inputs to the model, used for generating the torchscript + """ + if model_inputs is None: + _logger.debug("Model inputs are not given, batch normalization folding is disabled") + return + + graph = build_module_graph(model, model_inputs) + for node_group in graph.nodes_py.nodes_op: + if node_group.op_type in BN_FOLD_OP: + successors = graph.find_successors(node_group.unique_name) + successors = [graph.name_to_node[x] for x in successors] + for successor in successors: + if successor.op_type == 'BatchNorm2d': + self.conv_bn_patterns[node_group.name] = successor.name + def step_with_optimizer(self): pass @@ -677,6 +778,8 @@ class QuantType: 2: "output" } +BN_FOLD_OP = ["Conv2d"] + class QuantGrad(torch.autograd.Function): """ Base class for overriding backward function of quantization operation. @@ -773,6 +876,12 @@ def _check_weight(module): except AttributeError: return False +def _check_bias(module): + try: + return isinstance(module.bias.data, torch.Tensor) + except AttributeError: + return False + def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs): if quant_type == QuantType.QUANT_INPUT: output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs) From b2ce9baf3f12212542b0ab7573817d2c480f7558 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 12 Jul 2021 17:13:00 +0800 Subject: [PATCH 2/9] refine --- .../pytorch/quantization/quantizers.py | 29 +++++++++++++++---- nni/compression/pytorch/compressor.py | 16 +++++----- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 14e2af24fe..b0dc88c7ec 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -6,7 +6,7 @@ import torch from schema import Schema, And, Or, Optional from nni.compression.pytorch.utils.config_validation import QuantizerSchema -from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType +from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer'] @@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer): http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf """ - def __init__(self, model, config_list, optimizer=None, model_inputs=None): + def __init__(self, model, config_list, optimizer=None, dummy_input=None): """ Parameters ---------- @@ -145,11 +145,11 @@ def __init__(self, model, config_list, optimizer=None, model_inputs=None): state where activation quantization ranges do not exclude a significant fraction of values, default value is 0 - op_types : list of string types of nn.module you want to apply quantization, eg. 
'Conv2d' - - model_inputs : tuple of tensor + - dummy_input : tuple of tensor inputs to the model, which are used to get the graph of the module """ - super().__init__(model, config_list, optimizer, model_inputs) + super().__init__(model, config_list, optimizer, dummy_input) self.quant_grad = QATGrad.apply modules_to_compress = self.get_modules_to_compress() device = next(model.parameters()).device @@ -174,7 +174,7 @@ def _del_simulated_attr(self, module): """ del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', - 'activation_bit'] + 'activation_bit', 'BN_FOLD_TAG'] for attr in del_attr_list: if hasattr(module, attr): delattr(module, attr) @@ -338,6 +338,23 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config[name]['weight_bit'] = int(module.weight_bit) calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input) calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input) + + # Recover weight/bias for batch normalization folding + if hasattr(module, BN_FOLD_TAG): + actual_weight = getattr(module, 'old_weight', None) + if actual_weight is None: + logger.warning("Can not recover weight for layer {}. " + "This may lead to a wrong accuracy performance on the backend.".format(name)) + delattr(module, 'weight') + module.register_parameter('weight', actual_weight) + + actual_bias = getattr(module, 'old_bias', None) + delattr(module, 'bias') + if actual_bias is not None: + module.register_parameter('bias', actual_bias) + else: + setattr(module, 'bias', None) + if hasattr(module, 'activation_bit'): calibration_config[name]['activation_bit'] = int(module.activation_bit) calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation) @@ -370,7 +387,7 @@ def fold_bn(self, *inputs, wrapper): output = module(*inputs) _ = bn_module(output) running_mean = bn_module.running_mean - running_var = torch.sqrt(bn_module.running_var + 1e-10) + running_var = torch.sqrt(bn_module.running_var + bn_module.eps) bn_weight = bn_module.weight bn_bias = bn_module.bias dimensions = len(module.weight.shape) diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index c93600b39d..4dcde45422 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -514,7 +514,7 @@ def __init__(self, module, module_name, module_type, config, quantizer, bn_modul init_tensor = torch.zeros_like(self.bn_module.weight) delattr(self.module, 'bias') self.module.register_buffer('bias', init_tensor) - + setattr(module, BN_FOLD_TAG, True) def forward(self, *inputs): if 'input' in self.config['quant_types']: @@ -528,6 +528,7 @@ def forward(self, *inputs): # simulate batch normalization folding new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self) self.module.bias = new_bias + self.module.weight = new_weight else: new_weight = self.module.old_weight @@ -571,10 +572,10 @@ class Quantizer(Compressor): Base quantizer for pytorch quantizer """ - def __init__(self, model, config_list, optimizer=None, model_inputs=None): + def __init__(self, model, config_list, optimizer=None, dummy_input=None): self.identity_wrappers = [] self.conv_bn_patterns = {} - self.find_conv_bn_patterns(model, model_inputs) + self.find_conv_bn_patterns(model, dummy_input) super().__init__(model, config_list, optimizer) self.quant_grad = 
QuantGrad.apply if self.optimizer is not None: @@ -737,7 +738,7 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ """ raise NotImplementedError('Quantizer must overload export_model()') - def find_conv_bn_patterns(self, model, model_inputs): + def find_conv_bn_patterns(self, model, dummy_input): """ Find all Conv-BN patterns, used for batch normalization folding @@ -745,14 +746,14 @@ def find_conv_bn_patterns(self, model, model_inputs): ---------- model : torch.nn.Module model to be analyzed. - model_inputs : tupel of torch.tensor + dummy_input : tupel of torch.tensor inputs to the model, used for generating the torchscript """ - if model_inputs is None: + if dummy_input is None: _logger.debug("Model inputs are not given, batch normalization folding is disabled") return - graph = build_module_graph(model, model_inputs) + graph = build_module_graph(model, dummy_input) for node_group in graph.nodes_py.nodes_op: if node_group.op_type in BN_FOLD_OP: successors = graph.find_successors(node_group.unique_name) @@ -779,6 +780,7 @@ class QuantType: } BN_FOLD_OP = ["Conv2d"] +BN_FOLD_TAG = 'BN_FOLD_TAG' class QuantGrad(torch.autograd.Function): """ From de4fbee70a4cb507bd33d33f1b69c7008c29ba20 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 12 Jul 2021 19:12:58 +0800 Subject: [PATCH 3/9] fix linter --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index b0dc88c7ec..3b5c0ac752 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -343,8 +343,8 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ if hasattr(module, BN_FOLD_TAG): actual_weight = getattr(module, 'old_weight', None) if actual_weight is None: - logger.warning("Can not recover weight for layer {}. " - "This may lead to a wrong accuracy performance on the backend.".format(name)) + logger.warning("Can not recover weight for layer %s. " + "This may lead to a wrong accuracy performance on the backend.", name) delattr(module, 'weight') module.register_parameter('weight', actual_weight) From 408e9b253a7a387ab3ec461dacc9ab07977e16bf Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Sun, 18 Jul 2021 15:37:25 +0800 Subject: [PATCH 4/9] update docs --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 3b5c0ac752..0e7ff11e96 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -146,7 +146,9 @@ def __init__(self, model, config_list, optimizer=None, dummy_input=None): - op_types : list of string types of nn.module you want to apply quantization, eg. 'Conv2d' - dummy_input : tuple of tensor - inputs to the model, which are used to get the graph of the module + inputs to the model, which are used to get the graph of the module. The graph is used to find + Conv-Bn patterns. And then the batch normalization folding would be enabled. If dummy_input is not + given, then batch normalization folding would be disabled. 
""" super().__init__(model, config_list, optimizer, dummy_input) From 7368b5891c07cb4be39ba35600447016b5bcbdf4 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Sun, 18 Jul 2021 15:42:12 +0800 Subject: [PATCH 5/9] update docs --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 0e7ff11e96..09e53329b0 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -148,7 +148,7 @@ def __init__(self, model, config_list, optimizer=None, dummy_input=None): - dummy_input : tuple of tensor inputs to the model, which are used to get the graph of the module. The graph is used to find Conv-Bn patterns. And then the batch normalization folding would be enabled. If dummy_input is not - given, then batch normalization folding would be disabled. + given, the batch normalization folding would be disabled. """ super().__init__(model, config_list, optimizer, dummy_input) From 16f702f4065f1dd4df291dc18d4fd9e8ddf51559 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 19 Jul 2021 11:27:37 +0800 Subject: [PATCH 6/9] update docs --- docs/en_US/Compression/Quantizer.rst | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/docs/en_US/Compression/Quantizer.rst b/docs/en_US/Compression/Quantizer.rst index 073d0cd961..aff625d721 100644 --- a/docs/en_US/Compression/Quantizer.rst +++ b/docs/en_US/Compression/Quantizer.rst @@ -82,10 +82,25 @@ configuration needed by this algorithm : disable quantization until model are run by certain number of steps, this allows the network to enter a more stable state where activation quantization ranges do not exclude a significant fraction of values, default value is 0 -note -^^^^ +Batch normalization folding +^^^^^^^^^^^^^^^^^^^^^^^^^^^ -batch normalization folding is currently not supported. +Batch normalization folding is supported in QAT quantizer. It can be easily enabled by passing an argument `dummy_input` to +the quantizer, like: + +.. code-block:: python + + # assume your model takes an input of shape (1, 1, 28, 28) + # and dummy_input must be on the same device as the model + dummy_input = torch.randn(1, 1, 28, 28) + + # pass the dummy_input to the quantizer + quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input) + + +The quantizer will automatically detect Conv-BN patterns and simulate batch normalization folding process in the training +graph. Note that when the quantization aware training process is finished, the folded weight/bias would be restored after calling +`quantizer.export`. ---- From 16fe54cd650a23250d5b8266a80f1281a46d5889 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 19 Jul 2021 11:37:40 +0800 Subject: [PATCH 7/9] update docs --- docs/en_US/Compression/Quantizer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/Compression/Quantizer.rst b/docs/en_US/Compression/Quantizer.rst index aff625d721..8b5ffd99f9 100644 --- a/docs/en_US/Compression/Quantizer.rst +++ b/docs/en_US/Compression/Quantizer.rst @@ -100,7 +100,7 @@ the quantizer, like: The quantizer will automatically detect Conv-BN patterns and simulate batch normalization folding process in the training graph. 
Note that when the quantization aware training process is finished, the folded weight/bias would be restored after calling -`quantizer.export`. +`quantizer.export_model`. ---- From 12af635a00b7d8957ffd61390d4667867293f74a Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 19 Jul 2021 12:29:59 +0800 Subject: [PATCH 8/9] update docs --- docs/en_US/Compression/Quantizer.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/en_US/Compression/Quantizer.rst b/docs/en_US/Compression/Quantizer.rst index 8b5ffd99f9..1dd3ab16f0 100644 --- a/docs/en_US/Compression/Quantizer.rst +++ b/docs/en_US/Compression/Quantizer.rst @@ -97,7 +97,6 @@ the quantizer, like: # pass the dummy_input to the quantizer quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input) - The quantizer will automatically detect Conv-BN patterns and simulate batch normalization folding process in the training graph. Note that when the quantization aware training process is finished, the folded weight/bias would be restored after calling `quantizer.export_model`. From e4d72380279c67667a094f4823d8b28a1a950dbe Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 19 Jul 2021 14:30:39 +0800 Subject: [PATCH 9/9] update docs --- docs/en_US/Compression/Quantizer.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/en_US/Compression/Quantizer.rst b/docs/en_US/Compression/Quantizer.rst index 1dd3ab16f0..8b5ffd99f9 100644 --- a/docs/en_US/Compression/Quantizer.rst +++ b/docs/en_US/Compression/Quantizer.rst @@ -97,6 +97,7 @@ the quantizer, like: # pass the dummy_input to the quantizer quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input) + The quantizer will automatically detect Conv-BN patterns and simulate batch normalization folding process in the training graph. Note that when the quantization aware training process is finished, the folded weight/bias would be restored after calling `quantizer.export_model`.
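The documentation above only shows how the quantizer is constructed. For reference, here is a fuller sketch of the intended workflow, assuming the standard NNI compression API (`compress()` / `export_model()`) and the import path used in NNI's examples; the toy network, layer sizes and file names are illustrative placeholders rather than part of this patch:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    # A toy Conv-BN network; any model containing Conv2d -> BatchNorm2d pairs is handled the same way.
    class ConvBNNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 8, 3, padding=1)
            self.bn = torch.nn.BatchNorm2d(8)
            self.fc = torch.nn.Linear(8 * 28 * 28, 10)

        def forward(self, x):
            x = F.relu(self.bn(self.conv(x)))
            return self.fc(x.flatten(1))

    model = ConvBNNet()
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d', 'Linear']
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # dummy_input enables Conv-BN detection and must be on the same device as the model
    dummy_input = torch.randn(1, 1, 28, 28)
    quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()

    # ... run the usual quantization-aware training loop on `model` here ...

    # export_model restores the original (unfolded) weight/bias on folded Conv layers
    # and returns the calibration configuration collected during training
    calibration_config = quantizer.export_model('qat_model.pth', 'qat_calibration.pth')

After construction, the detected pairs are recorded in `quantizer.conv_bn_patterns`, a dict mapping each Conv2d layer name to the name of the BatchNorm2d layer that follows it, so inspecting it is a quick way to confirm that folding is active for the expected layers.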
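The folding rule that `fold_bn` simulates is the usual one: with BatchNorm running mean μ and variance σ², scale γ, shift β and epsilon ε, the folded weight is W·γ/√(σ²+ε) per output channel and the folded bias is β + (b - μ)·γ/√(σ²+ε). A minimal, NNI-independent sanity check of that identity (shown only to make the arithmetic concrete, assuming an eval-mode Conv2d followed by BatchNorm2d):

.. code-block:: python

    import torch

    torch.manual_seed(0)
    conv = torch.nn.Conv2d(3, 16, 3, bias=True)
    bn = torch.nn.BatchNorm2d(16)

    # Populate the BN running statistics, then freeze them; fold_bn works with
    # running_mean / running_var rather than per-batch statistics.
    with torch.no_grad():
        for _ in range(10):
            bn(conv(torch.randn(8, 3, 32, 32)))
    conv.eval()
    bn.eval()

    gamma, beta = bn.weight, bn.bias
    mu, sigma = bn.running_mean, torch.sqrt(bn.running_var + bn.eps)

    folded = torch.nn.Conv2d(3, 16, 3, bias=True)
    with torch.no_grad():
        folded.weight.copy_(conv.weight * (gamma / sigma).reshape(-1, 1, 1, 1))
        folded.bias.copy_(beta + (conv.bias - mu) * gamma / sigma)
    folded.eval()

    x = torch.randn(4, 3, 32, 32)
    # In eval mode the folded convolution matches Conv2d followed by BatchNorm2d.
    print(torch.allclose(folded(x), bn(conv(x)), atol=1e-5))  # True

This folded weight/bias pair is essentially what the wrapper feeds to the weight quantizer during training (after refreshing the running statistics with the current batch), while the original BatchNorm module itself is bypassed through `QuantizerIdentityWrapper`.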