Update 20230629
* Fixed bugs on the trainable graph
* Added an adaround-like optimization procedure
* Added weight equalization parameters
* Fixed bugs related to ProgramEntrance_2
ZhangZhiPku committed Jun 29, 2023
1 parent 201d31c commit 1dcd585
Showing 10 changed files with 609 additions and 43 deletions.
13 changes: 3 additions & 10 deletions ProgramEntrance_2.py
@@ -162,6 +162,7 @@ def init_quantize_config(self, operation: Operation) -> OperationQuantizationCon

return OQC

@ property
def quant_operation_types(self) -> set:
return {'Conv', 'ConvTranspose', 'MatMul', 'Gemm',
'Relu', 'Clip', 'Sub', 'Abs', 'Mul',
@@ -178,9 +179,8 @@ class MyOptimPass(QuantizationOptimizationPass):
This Optimization Pass will:
1. fuse relu - clip structure.
2. set clip output scale in the network to 1/127.
3. exclude input variables in the network from quantization.
4. set the input and output quantization information of the abs operators to be the same.
5. modify calibration method for some operators.
3. set the input and output quantization information of the abs operators to be the same.
4. modify calibration method for some operators.
"""
def __init__(self, name: str = 'My Optim Pass') -> None:
super().__init__(name)
@@ -204,13 +204,6 @@ def optimize(self, graph: BaseGraph, **kwargs) -> None:
clip.config.output_quantization_config[0].offset = torch.tensor(0.0).cuda()
clip.config.output_quantization_config[0].state = QuantizationStates.ACTIVATED

# disable input quantization of this network
for name, var in graph.inputs.items():
print(f'Variable {name} has dequantized.')
if isinstance(var, QuantableVariable):
for config in var.dest_op_configs:
config.state = QuantizationStates.FP32

# keep input and output scale of abs as the same.
for op in graph.operations.values():
print(f'Op {op.name} has processed.')
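Note on the ProgramEntrance_2.py change: besides marking `quant_operation_types` as a `@property`, this commit drops the block that forced the network inputs back to FP32. If that behaviour is still wanted, the sketch below reproduces it as a standalone pass, reusing only the calls visible in the deleted lines; the import paths and the pass name are assumptions, not part of the commit.

```python
# Hypothetical pass that restores the removed "dequantize graph inputs" behaviour.
# Import paths are assumed from the PPQ layout; verify them against your version.
from ppq.core import QuantizationStates
from ppq.IR import BaseGraph, QuantableVariable
from ppq.quantization.optim import QuantizationOptimizationPass

class DequantizeGraphInputPass(QuantizationOptimizationPass):
    def __init__(self) -> None:
        super().__init__(name='Dequantize Graph Input')

    def optimize(self, graph: BaseGraph, **kwargs) -> None:
        # keep every network input in FP32 by disabling its destination configs
        for name, var in graph.inputs.items():
            if isinstance(var, QuantableVariable):
                for config in var.dest_op_configs:
                    config.state = QuantizationStates.FP32
                print(f'Variable {name} has been dequantized.')
```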
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@ PPQ is developed in close connection with inference frameworks, which enables us to understand hardware inference
3. More powerful [graph pattern matching](https://github.com/openppl-public/ppq/blob/master/ppq/IR/search.py) and [graph fusion](https://github.com/openppl-public/ppq/blob/master/ppq/IR/morph.py) capabilities
4. Onnx-based model [QAT](https://github.com/openppl-public/ppq/blob/master/ppq/samples/QAT/imagenet.py) support
5. Brand-new [TensorRT](https://github.com/openppl-public/ppq/blob/master/md_doc/deploy_trt_by_OnnxParser.md) quantization and export logic
6. More sample scripts and video content, continuously updated
6. The world's largest quantization model zoo, [OnnxQuant](https://github.com/openppl-public/ppq/tree/master/ppq/samples/QuantZoo)
7. Other software features yet to be discovered

### Installation (安装方法)
8 changes: 5 additions & 3 deletions ppq/IR/training.py
@@ -2,6 +2,8 @@

import torch

from ppq import DataType

from .base.graph import BaseGraph
from .processer import GraphCommandProcessor

@@ -15,20 +17,20 @@ def __init__(self, graph_or_processor: Union[BaseGraph, Callable]) -> None:
def parameters(self) -> List[torch.Tensor]:
parameters = []
for var in self.graph.variables.values():
if var.is_parameter and var.dtype == torch.float:
if var.is_parameter and DataType.to_torch(var.dtype) == torch.float:
parameters.append(var.value)
return parameters

def zero_grad(self):
for var in self.graph.variables.values():
if var.is_parameter and var.dtype == torch.float:
if var.is_parameter and DataType.to_torch(var.dtype) == torch.float:
if var.value._grad is not None:
var.value._grad.zero_()

def state_dict(self) -> dict:
parameters = {}
for var in self.graph.variables.values():
if var.is_parameter and var.dtype == torch.float:
if var.is_parameter and DataType.to_torch(var.dtype) == torch.float:
parameters[var.name] = var.value
return parameters

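The training.py fix matters because `var.dtype` in PPQ holds a `DataType` enum member rather than a torch dtype, so the old `var.dtype == torch.float` test matched nothing and the trainable graph collected no parameters. A small sketch of the difference, assuming the `DataType.FP32` member, with `DataType.to_torch` taken from the diff:

```python
# Why the old comparison failed: DataType is a PPQ enum, not a torch dtype.
import torch
from ppq import DataType

ppq_dtype = DataType.FP32                             # what a parameter's var.dtype holds
print(ppq_dtype == torch.float)                       # False -> old check skipped every parameter
print(DataType.to_torch(ppq_dtype) == torch.float)    # True  -> fixed check finds FP32 parameters
```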
75 changes: 54 additions & 21 deletions ppq/quantization/algorithm/equalization.py
@@ -20,13 +20,16 @@ class EqualizationMethod(Enum):
SQUARE_MAX = 3,
# key value = np.mean(np.square(x))
SQUARE_MEAN = 4,

DIRECT_MAX = 5


class EqualizationHelper():

@ staticmethod
def key_value_from_upstream(
op: Operation, including_bias: bool = False, including_act: bool = False,
op: Operation, including_weight: bool = True,
weight_multiplier: float = 1.0, including_bias: bool = False, including_act: bool = False,
bias_multiplier: float = 0.5, act_multiplier: float = 0.5) -> torch.Tensor:
if op.type not in {'Gemm', 'MatMul', 'Conv', 'ConvTranspose'}:
raise TypeError(f'Unsupported Op type {op.name}({op.type}) for Equalization Optimization.')
@@ -37,24 +40,25 @@ def key_value_from_upstream(
# ----------------------------------
# step - 1, extract weight from op:
# ----------------------------------
w = op.inputs[1].value
if op.type == 'ConvTranspose':
num_of_groups = op.attributes.get('group', 1)
if w.ndim == 3:
w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ])
w = torch.permute(w, (2, 0, 1, 3))
w = torch.reshape(w, (w.shape[0] * w.shape[1], -1))
elif w.ndim == 4:
w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ])
w = torch.permute(w, (2, 0, 1, 3, 4))
w = torch.reshape(w, (w.shape[0] * w.shape[1], -1))
elif w.ndim == 5:
w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ])
w = torch.permute(w, (2, 0, 1, 3, 4, 5))
w = torch.reshape(w, (w.shape[0] * w.shape[1], -1))
else:
raise ValueError(f'Unexpected dimension of weight of {op.name}.')
buffer.append(w)
if including_weight:
w = op.inputs[1].value * weight_multiplier
if op.type == 'ConvTranspose':
num_of_groups = op.attributes.get('group', 1)
if w.ndim == 3:
w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ])
w = torch.permute(w, (2, 0, 1, 3))
w = torch.reshape(w, (w.shape[0] * w.shape[1], -1))
elif w.ndim == 4:
w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ])
w = torch.permute(w, (2, 0, 1, 3, 4))
w = torch.reshape(w, (w.shape[0] * w.shape[1], -1))
elif w.ndim == 5:
w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ])
w = torch.permute(w, (2, 0, 1, 3, 4, 5))
w = torch.reshape(w, (w.shape[0] * w.shape[1], -1))
else:
raise ValueError(f'Unexpected dimension of weight of {op.name}.')
buffer.append(w)

if op.type in {'MatMul', 'Gemm'}:
assert w.ndim == 2, f'Unexpected Error, Parameter of MatMul {op.name} should be 2-d.'
@@ -90,11 +94,11 @@ def key_value_from_upstream(
return torch.cat(buffer, dim=-1)

@ staticmethod
def key_value_from_downstream(op: Operation) -> torch.Tensor:
def key_value_from_downstream(op: Operation, weight_multiplier: float = 1.0) -> torch.Tensor:
# ----------------------------------
# step - 1, extract weight from op:
# ----------------------------------
w = op.inputs[1].value
w = op.inputs[1].value * weight_multiplier
if op.type == 'ConvTranspose':
w = torch.reshape(w, (w.shape[0], -1))

@@ -318,6 +322,8 @@ def __init__(
def equalize(
self,
value_threshold: float,
including_weight: bool = True,
weight_multiplier: float = 1.0,
including_act: bool = False,
act_multiplier: float = 0.5,
including_bias: bool = False,
@@ -386,6 +392,30 @@ def channel_split(
ChannelSplitHelper.channel_split_downstream(
op = op, scale_factor = 1 / sqrt(2), mask = mask)

def activation_equalize(self, threshold: float = 4.0):
# extract key value from pair
upstream_key_values = []
for op in self.upstream_layers:
key_value = EqualizationHelper.key_value_from_upstream(
op=op, including_bias=False, including_act=True,
bias_multiplier=0, act_multiplier=1, weight_multiplier=0)
upstream_key_values.append(key_value)
upstream_key_values = self.reduce_by_axis(upstream_key_values, method=EqualizationMethod.ABSOLUTE_MAX)

threshold = upstream_key_values.mean().item()
# calculate scale
scale = torch.where(
upstream_key_values > threshold,
torch.ones_like(upstream_key_values) * 1.1,
torch.ones_like(upstream_key_values) / 1.1)

# write back all params
for op in self.upstream_layers:
EqualizationHelper.scale_to_upstream(op, 1 / scale)

for op in self.downstream_layers:
EqualizationHelper.scale_to_downstream(op, 1 / scale)

def calculate_scale(
self, upstream_key_values: torch.Tensor,
downstream_key_values: torch.Tensor,
@@ -413,6 +443,9 @@ def reduce_by_axis(

elif method is EqualizationMethod.SQUARE_MEAN:
return torch.mean(torch.square(params), axis=axis)

elif method is EqualizationMethod.DIRECT_MAX:
return torch.max(params, axis=axis).values

else:
raise NotImplementedError('Equalization method %s is not supported.' % str(method))
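The equalization changes add `including_weight` / `weight_multiplier` knobs, a `DIRECT_MAX` reduction, and an `activation_equalize` routine that builds per-channel statistics from activations alone and nudges channels toward the mean with a fixed 1.1 factor. A self-contained sketch of that scaling rule, mirroring the code above with invented numbers:

```python
# Activation-equalization scale rule from this commit, with made-up statistics.
# Channels above the mean get scale 1.1 (so 1/scale damps them), the rest 1/1.1
# (so 1/scale boosts them); the pass writes 1/scale to upstream and downstream layers.
import torch

upstream_key_values = torch.tensor([0.2, 3.0, 0.5, 6.0])   # per-channel activation maxima (invented)
threshold = upstream_key_values.mean().item()

scale = torch.where(
    upstream_key_values > threshold,
    torch.ones_like(upstream_key_values) * 1.1,
    torch.ones_like(upstream_key_values) / 1.1)

print(1 / scale)   # factors applied via scale_to_upstream / scale_to_downstream
```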
168 changes: 168 additions & 0 deletions ppq/quantization/algorithm/training.py
@@ -420,3 +420,171 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> to
quantized = quantized
return quantized


class RoundTuningDelegator(TorchQuantizeDelegator):
def __init__(
self, config: TensorQuantizationConfig, var: Variable,
) -> None:
self.config = config
self.is_parameter = var.is_parameter
self.var = var
self.policy = config.policy
self.tunable_round = self.var.value

def trainable_tensors(self) -> List[torch.Tensor]:
params = []
if self.is_parameter: params.append(self.var.value)
return params

def withdraw(self) -> None:
with torch.no_grad():
if self.scale_backup is not None:
self.config.scale.copy_(self.scale_backup)
if self.offset_backup is not None:
self.config.offset.copy_(self.offset_backup)
if self.param_backup is not None:
self.var.value.copy_(self.param_backup)

def finalize(self) -> None:
self.scale_backup = None
self.offset_backup = None
self.param_backup = None
pass # do nothing here.

def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
if tensor.is_cuda and PPQ_CONFIG.USING_CUDA_KERNEL:
if config.policy.has_property(QuantizationProperty.LINEAR):
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
return CuLSQ_LC.apply(
tensor, config.scale, config.offset, config.channel_axis,
config.quant_min, config.quant_max, config.rounding)
elif config.policy.has_property(QuantizationProperty.PER_TENSOR):
return CuLSQ_LT.apply(
tensor, config.scale, config.offset,
config.quant_min, config.quant_max, config.rounding)

elif config.policy.has_property(QuantizationProperty.FLOATING):
# For floating quantization, scale is not trainable.
return PPQuantFunction(tensor=tensor, config=config)

else:
scale, offset = config.scale, config.offset

if self.is_scale_trainable:
scale = scale.abs()
grad_scale = 1 / (tensor.numel() * config.quant_max) ** 0.5
scale = scale * grad_scale + (scale - scale * grad_scale).detach()

if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
shape = [1 if axis != config.channel_axis else -1 for axis in range(tensor.ndim)]
scale = scale.view(shape)
offset = offset.view(shape)

quantized = ppq_tensor_round((tensor / scale), config.rounding) + offset.detach()
quantized = torch.clamp(quantized, config.quant_min, config.quant_max)
quantized = (quantized - offset.detach()) * scale
quantized = quantized
return quantized


class TensorwiseRoundTuningImpl(Function):
@ staticmethod
def forward(ctx, tensor: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, quant_min: int, quant_max: int,
rounding: torch.Tensor) -> torch.Tensor:

scale, offset = scale.to(tensor.device), offset.to(tensor.device)
tensor = (tensor / scale) + (rounding > .5) + offset
tensor = torch.clamp(tensor, quant_min, quant_max)
tensor = (tensor - offset) * scale
return tensor

@ staticmethod
def backward(ctx, dy: torch.Tensor):
return dy, None, None, None, None, dy


class ChannelwiseRoundTuningImpl(Function):
@ staticmethod
def forward(ctx, tensor: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, channel_axis: int,
quant_min: int, quant_max: int,
rounding: torch.Tensor) -> torch.Tensor:

scale, offset = scale.to(tensor.device), offset.to(tensor.device)
# generate a shape like [1, 1, -1, 1]; the only -1 is at the channel axis.
shape = [1 if axis != channel_axis else -1 for axis in range(tensor.ndim)]
scale, offset = scale.view(shape), offset.view(shape)

tensor = (tensor / scale) + (rounding > .5) + offset
tensor = torch.clamp(tensor, quant_min, quant_max)
tensor = (tensor - offset) * scale
return tensor

@ staticmethod
def backward(ctx, dy: torch.Tensor):
return dy, None, None, None, None, None, dy


class RoundTruningDelegator(TorchQuantizeDelegator):
def __init__(
self, var: Variable,
config: TensorQuantizationConfig,
) -> None:
self.config = config
self.var = var
self.is_parameter = self.var.is_parameter

# environment check
if config.policy.has_property(QuantizationProperty.FLOATING):
raise TypeError('Incorrect Quantization Property. Expected Linear Quantization Policy.')
if config.policy.has_property(QuantizationProperty.DYNAMIC):
raise TypeError('Incorrect Quantization Property. Expected Static Quantization Policy.')
if not self.var.is_parameter:
raise TypeError(f'Variable {self.var.name} is not a parameter!')
if self.var.value is None or not isinstance(self.var.value, torch.Tensor):
raise ValueError(f'Unexpected value type of {self.var.name}')
if self.config.scale is None:
raise ValueError(f'Quantization Scale has not been correctly set.')

# initialize rounding
self._calling_times = 0
self._executing_device = self.var.value.device
self._param_backup = self.var.value.clone()

with torch.no_grad():
scale, _ = config.scale, config.offset
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
shape = [1 if axis != config.channel_axis else -1 for axis in range(self.var.value.ndim)]
scale = scale.view(shape)

rounding = ((self.var.value / scale) - (self.var.value / scale).floor())
self.var.value = (self.var.value / scale).floor() * scale
self._scale = scale

# create grad.
self._rounding = rounding
self._rounding.requires_grad = True

def trainable_tensors(self) -> List[torch.Tensor]:
return [self._rounding]

def finalize(self) -> None:
with torch.no_grad():
self.var.value += (self._rounding > .5) * self._scale

def withdraw(self) -> None:
with torch.no_grad():
self.var.value.copy_(self._param_backup)

def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
if config.policy.has_property(QuantizationProperty.PER_TENSOR):
return TensorwiseRoundTuningImpl.apply(
tensor, config.scale, config.offset, config.quant_min,
config.quant_max, self._rounding)
elif config.policy.has_property(QuantizationProperty.PER_CHANNEL):
return ChannelwiseRoundTuningImpl.apply(
tensor, config.scale, config.offset, config.channel_axis,
config.quant_min, config.quant_max, self._rounding)
else:
raise Exception('Oops, this should not happen.')
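The new delegators implement the adaround-like procedure from the commit message: `RoundTruningDelegator` snaps each weight down to `floor(w / scale) * scale`, keeps the fractional remainder as the trainable tensor `_rounding`, and on `finalize` rounds up exactly the entries whose learned residual ended above 0.5. A tiny numeric sketch of that decomposition (plain PyTorch, values invented):

```python
# Rounding decomposition behind RoundTruningDelegator, on invented numbers.
import torch

scale = torch.tensor(0.25)
w     = torch.tensor([0.30, -0.10, 0.62])

rounding = (w / scale) - (w / scale).floor()   # trainable residual in [0, 1)
w_floor  = (w / scale).floor() * scale         # value stored while training

# finalize(): round up wherever the learned residual exceeds 0.5
w_final = w_floor + (rounding > .5) * scale
print(w_floor)    # tensor([ 0.2500, -0.2500,  0.5000])
print(w_final)    # tensor([ 0.2500,  0.0000,  0.5000])
```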
5 changes: 3 additions & 2 deletions ppq/quantization/optim/__init__.py
@@ -3,7 +3,8 @@
QuantizationOptimizationPipeline)
from .calibration import (IsotoneCalibrationPass, PPLDSPTIReCalibrationPass,
RuntimeCalibrationPass)
from .equalization import ChannelwiseSplitPass, LayerwiseEqualizationPass
from .equalization import (ActivationEqualizationPass, ChannelwiseSplitPass,
LayerwiseEqualizationPass)
from .extension import ExtensionPass
from .legacy import AdaroundPass
from .morph import (GRUSplitPass, HorizontalLayerSplitPass, MetaxGemmSplitPass,
@@ -13,4 +14,4 @@
NxpQuantizeFusionPass, QuantAlignmentPass,
QuantizeFusionPass, QuantizeSimplifyPass, SwishFusionPass)
from .ssd import SSDEqualizationPass
from .training import BiasCorrectionPass, LearnedStepSizePass
from .training import BiasCorrectionPass, LearnedStepSizePass, RoundTuningPass
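With the exports above, the new passes plug into a standard optimization pipeline. A hedged usage sketch follows; the constructors are shown without arguments purely as an assumption, since their signatures are not part of this diff.

```python
# Assumed wiring of the newly exported passes; check the pass definitions in
# ppq/quantization/optim for their actual constructor parameters.
from ppq.quantization.optim import (ActivationEqualizationPass,
                                    QuantizationOptimizationPipeline,
                                    RoundTuningPass)

pipeline = QuantizationOptimizationPipeline(passes=[
    ActivationEqualizationPass(),
    RoundTuningPass(),
])
```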