Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Add batch normalization folding to QAT quantizer
Browse files Browse the repository at this point in the history
  • Loading branch information
chenbohua3 committed Jul 8, 2021
1 parent 507595b commit 555f1a0
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 14 deletions.
48 changes: 41 additions & 7 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class QAT_Quantizer(Quantizer):
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""

def __init__(self, model, config_list, optimizer=None):
def __init__(self, model, config_list, optimizer=None, model_inputs=None):
"""
Parameters
----------
Expand All @@ -144,8 +144,11 @@ def __init__(self, model, config_list, optimizer=None):
state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
- model_inputs : tuple of tensor
inputs to the model, which are used to get the graph of the module
"""
super().__init__(model, config_list, optimizer)

super().__init__(model, config_list, optimizer, model_inputs)
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
Expand All @@ -168,8 +171,9 @@ def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
'activation_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
Expand Down Expand Up @@ -342,9 +346,39 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_

return calibration_config

def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
def fold_bn(self, *inputs, wrapper):
"""
Simulate batch normalization folding in the training graph. Folded weight and bias are
returned for the following operations.
Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for origin module
Returns
-------
Tuple of torch.Tensor
"""
module = wrapper.module
bn_module = wrapper.bn_module
with torch.no_grad():
output = module(*inputs)
_ = bn_module(output)
running_mean = bn_module.running_mean
running_var = torch.sqrt(bn_module.running_var + 1e-10)
bn_weight = bn_module.weight
bn_bias = bn_module.bias
dimensions = len(module.weight.shape)
shape = [-1] + [1] * (dimensions - 1)
new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
if hasattr(module, 'old_bias'):
new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
else:
new_bias = bn_bias - running_mean / running_var * bn_weight
return new_weight, new_bias

def step_with_optimizer(self):
"""
Expand Down
123 changes: 116 additions & 7 deletions nni/compression/pytorch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import types
import logging
import torch
from nni.common.graph_utils import build_module_graph
from . import default_layers

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -463,7 +464,7 @@ def get_pruned_weights(self, dim=0):


class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Expand All @@ -479,6 +480,8 @@ def __init__(self, module, module_name, module_type, config, quantizer):
the type of the module to compress
quantizer :quantizer
the quantizer used to calculate mask
bn_module : torch.nn.Module
batch norm layer corresponding to current module, used for simulating batch normalization folding
"""
super().__init__()
# origin layer information
Expand All @@ -488,6 +491,7 @@ def __init__(self, module, module_name, module_type, config, quantizer):
# config and pruner
self.config = config
self.quantizer = quantizer
self.bn_module = bn_module

# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
Expand All @@ -501,6 +505,17 @@ def __init__(self, module, module_name, module_type, config, quantizer):
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)

# for batch normalization folding
if self.bn_module is not None:
if _check_bias(self.module):
self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
init_tensor = self.module.old_bias
else:
init_tensor = torch.zeros_like(self.bn_module.weight)
delattr(self.module, 'bias')
self.module.register_buffer('bias', init_tensor)


def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad(
Expand All @@ -509,13 +524,19 @@ def forward(self, *inputs):
self)

if 'weight' in self.config['quant_types'] and _check_weight(self.module):
if self.bn_module is not None:
# simulate batch normalization folding
new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
self.module.bias = new_bias
else:
new_weight = self.module.old_weight

self.quantizer.quant_grad(
self.module.old_weight,
new_weight,
QuantType.QUANT_WEIGHT,
self, inputs[0])
result = self.module(*inputs)
else:
result = self.module(*inputs)

result = self.module(*inputs)

if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad(
Expand All @@ -525,12 +546,35 @@ def forward(self, *inputs):
return result


class QuantizerIdentityWrapper(torch.nn.Module):
def __init__(self, module, module_name):
"""
Used to wrap modules that should be treated as torch.Identity
Parameters
----------
module : pytorch module
the module to be wrapped
module_name : str
the name of the module to wrapped, wrapper module shares same name
"""
super().__init__()
self.module = module
self.module_name = module_name

def forward(self, x):
return x


class Quantizer(Compressor):
"""
Base quantizer for pytorch quantizer
"""

def __init__(self, model, config_list, optimizer=None):
def __init__(self, model, config_list, optimizer=None, model_inputs=None):
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, model_inputs)
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad.apply
if self.optimizer is not None:
Expand All @@ -540,6 +584,10 @@ def __init__(self, model, config_list, optimizer=None):
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})
# This is for conv with bias + bn. Although this situation is relatively rare,
# we still need to deal with the old_bias when it occurs
if hasattr(wrapper.module, "old_bias"):
self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")})

def quantize_weight(self, wrapper, **kwargs):
"""
Expand Down Expand Up @@ -597,7 +645,36 @@ def _wrap_modules(self, layer, config):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
# bound bn module to corresponding conv module
bn_module = None
if layer.name in self.conv_bn_patterns:
bn_module_name = self.conv_bn_patterns[layer.name]
for name, module in self.bound_model.named_modules():
if name == bn_module_name:
bn_module = module
break
assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name)
self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name))
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module)

def _wrap_model(self):
"""
wrap all modules that needed to be compressed
"""
# wrap folded bn in order to bypass its forward process
for wrapper in reversed(self.identity_wrappers):
_setattr(self.bound_model, wrapper.module_name, wrapper)
super()._wrap_model()

def _unwrap_model(self):
"""
unwrap all modules that needed to be compressed
"""
for wrapper in self.identity_wrappers:
_setattr(self.bound_model, wrapper.module_name, wrapper.module)
super()._unwrap_model()

def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
input_shape=None, device=None):
Expand Down Expand Up @@ -660,6 +737,30 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
"""
raise NotImplementedError('Quantizer must overload export_model()')

def find_conv_bn_patterns(self, model, model_inputs):
"""
Find all Conv-BN patterns, used for batch normalization folding
Parameters
----------
model : torch.nn.Module
model to be analyzed.
model_inputs : tupel of torch.tensor
inputs to the model, used for generating the torchscript
"""
if model_inputs is None:
_logger.debug("Model inputs are not given, batch normalization folding is disabled")
return

graph = build_module_graph(model, model_inputs)
for node_group in graph.nodes_py.nodes_op:
if node_group.op_type in BN_FOLD_OP:
successors = graph.find_successors(node_group.unique_name)
successors = [graph.name_to_node[x] for x in successors]
for successor in successors:
if successor.op_type == 'BatchNorm2d':
self.conv_bn_patterns[node_group.name] = successor.name

def step_with_optimizer(self):
pass

Expand All @@ -677,6 +778,8 @@ class QuantType:
2: "output"
}

BN_FOLD_OP = ["Conv2d"]

class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
Expand Down Expand Up @@ -773,6 +876,12 @@ def _check_weight(module):
except AttributeError:
return False

def _check_bias(module):
try:
return isinstance(module.bias.data, torch.Tensor)
except AttributeError:
return False

def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
Expand Down

0 comments on commit 555f1a0

Please sign in to comment.