Skip to content

Commit

Permalink
Enable seq mse with bq/lpbq
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Hsieh <[email protected]>
  • Loading branch information
quic-klhsieh authored Sep 26, 2024
1 parent 3cd23ce commit eb6337d
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 53 deletions.
8 changes: 4 additions & 4 deletions TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,12 +539,12 @@ def compute_outputs(cls,
module = cls._get_original_module(quant_module)

if isinstance(module, torch.nn.Linear):
xqwq = functional.linear(xq, wq, module.bias)
xw = functional.linear(x, w, module.bias)
xqwq = functional.linear(xq, wq)
xw = functional.linear(x, w)
elif isinstance(module, torch.nn.Conv2d):
xqwq = functional.conv2d(xq, wq, bias=module.bias, stride=module.stride, dilation=module.dilation,
xqwq = functional.conv2d(xq, wq, stride=module.stride, dilation=module.dilation,
padding=module.padding, groups=module.groups)
xw = functional.conv2d(x, w, bias=module.bias, stride=module.stride, dilation=module.dilation,
xw = functional.conv2d(x, w, stride=module.stride, dilation=module.dilation,
padding=module.padding, groups=module.groups)

# [N, C, H, W] --> [N, H, W, C], so that loss can be computed across channel dimension.
Expand Down
255 changes: 223 additions & 32 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,36 +42,71 @@
import contextlib
import torch
from torch import nn
from torch.utils.data import DataLoader

from aimet_common.utils import AimetLogger
from aimet_torch.seq_mse import SequentialMse as V1SequentialMse
from aimet_torch.seq_mse import SeqMseParams as V1SeqMseParams
from aimet_torch.seq_mse import SUPPORTED_MODULES
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.affine import AffineQuantizerBase
from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
from aimet_torch.v2.quantization.affine import AffineQuantizerBase, QuantizeDequantize, GroupedBlockQuantizeDequantize
from aimet_torch.v2.quantization.affine.backends import torch_builtins
from aimet_torch.v2.nn.base import BaseQuantizationMixin
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.utils import reduce, _is_reducible


SeqMseParams = V1SeqMseParams


def _observe(x_min: torch.Tensor,
x_max: torch.Tensor,
num_steps: int,
symmetric: bool) -> Tuple[torch.Tensor, torch.Tensor]:
encoding_analyzer = MinMaxEncodingAnalyzer(x_min.shape)
min, max = encoding_analyzer.compute_dynamic_encodings(torch.stack([x_min, x_max]),
num_steps=num_steps,
is_symmetric=symmetric)
return min, max

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.SeqMse)

class SequentialMse(V1SequentialMse):
"""
Sequentially minimizing activation MSE loss in layer-wise way to decide optimal param quantization encodings.
"""

@classmethod
def apply_seq_mse(cls,
model: torch.nn.Module,
sim: QuantizationSimModel,
data_loader: DataLoader,
params: SeqMseParams,
modules_to_exclude: Optional[List[torch.nn.Module]] = None,
checkpoints_config: Optional[str] = None):
if not modules_to_exclude:
modules_to_exclude = []
modules_to_exclude.extend(cls._get_grouped_convs_with_blockwise_quantization(sim))
with cls._handle_grouped_block_quantizers(sim):
super().apply_seq_mse(model, sim, data_loader, params, modules_to_exclude, checkpoints_config)

@staticmethod
def _get_grouped_convs_with_blockwise_quantization(sim):
""" Return a list of all grouped conv modules using blockwise quantization for weights """
grouped_convs_with_blockwise_quantization = []
for module in sim.model.modules():
if isinstance(module, torch.nn.Conv2d) and \
isinstance(module, BaseQuantizationMixin) and \
module.groups != 1 and \
module.param_quantizers['weight'].block_size is not None and \
module.param_quantizers['weight'].block_size[1] != module.weight.shape[1]:
grouped_convs_with_blockwise_quantization.append(module)
return grouped_convs_with_blockwise_quantization

@staticmethod
@contextlib.contextmanager
def _handle_grouped_block_quantizers(sim: QuantizationSimModel):
""" Set all grouped block quantizers to regular blockwise quantization for the duration of the context manager
"""
grouped_block_quantize_dequantizers = []
for module in sim.model.modules():
if isinstance(module, GroupedBlockQuantizeDequantize):
grouped_block_quantize_dequantizers.append((module, module.block_grouping))
module.block_grouping = tuple(1 for _ in enumerate(module.shape))

yield

for (module, block_grouping) in grouped_block_quantize_dequantizers:
module.block_grouping = block_grouping

@staticmethod
def compute_all_param_encodings(sim: QuantizationSimModel):
"""
Expand Down Expand Up @@ -143,27 +178,15 @@ def compute_param_encodings(quantizer: QuantizerBase,
:param x_min: min values
:param x_max: max values
"""
# Unsqueeze x_min/x_max until they become reducible to quantizer.min/max
while x_min.dim() < quantizer.min.dim():
x_min = x_min[..., None]
while x_max.dim() < quantizer.max.dim():
x_max = x_max[..., None]
assert _is_reducible(x_min.shape, quantizer.min.shape)
assert _is_reducible(x_max.shape, quantizer.max.shape)

x_min = reduce(x_min, quantizer.shape, torch.min).values
x_max = reduce(x_max, quantizer.shape, torch.max).values
quantize_dequantize = QuantizeDequantize(quantizer.shape, quantizer.bitwidth, quantizer.symmetric,
block_size=quantizer.block_size).to(x_min.device)

num_steps = 2 ** quantizer.bitwidth - 1
symmetric = quantizer.symmetric

# The values of x_min and x_max don't necessarily satisfy the symmetry constraints.
# Therefore, we need to adjust their values to ensure min and max are in symmetric grids.
min, max = _observe(x_min, x_max, num_steps=num_steps, symmetric=symmetric)
with quantize_dequantize.compute_encodings():
_ = quantize_dequantize(torch.stack([x_min, x_max]))

with torch.no_grad():
quantizer.min.copy_(min)
quantizer.max.copy_(max)
quantizer.min.copy_(quantize_dequantize.min)
quantizer.max.copy_(quantize_dequantize.max)

@staticmethod
def _is_symmetric_quantizer(quantizer: AffineQuantizerBase):
Expand All @@ -185,6 +208,174 @@ def _get_quantized_weight(quant_module: BaseQuantizationMixin):
def _get_original_module(quant_module: BaseQuantizationMixin):
return quant_module

@staticmethod
def _get_input_channel_block_size(quant_module):
if not isinstance(quant_module, (torch.nn.Linear, torch.nn.Conv2d)):
raise NotImplementedError('Unsupported module type: ', type(quant_module))
if quant_module.param_quantizers['weight'].block_size is None:
# Per tensor or per channel case. For either one, treat loss computation as per channel
return quant_module.weight.shape[1]
return quant_module.weight.shape[1] // quant_module.param_quantizers['weight'].shape[1]

@staticmethod
def _get_indices_to_reduce(block_size, reshaped_weight):
"""
Return indices in reshaped_weight corresponding to block_sizes. Reshaped_weight is expected to contain
alternating dimensions of num_blocks and block_sizes.
"""
indices_to_reduce = []
for idx, _ in enumerate(block_size):
indices_to_reduce.insert(0, (len(reshaped_weight.shape) - 2 * idx) - 1)
return indices_to_reduce

@classmethod
def get_min_and_max_for_candidate_selection(cls, quant_module: BaseQuantizationMixin) -> \
Tuple[torch.Tensor, torch.Tensor]:
"""
Get min/max values for candidate selection.
:param quant_module: Quant module to be optimized
:return: Tuple of min and max values for candidate selection.
"""
# pylint: disable=protected-access
assert hasattr(quant_module.param_quantizers['weight'], 'block_size')
if not isinstance(quant_module, (torch.nn.Conv2d, torch.nn.Linear)):
raise ValueError('Unsupported module: ', quant_module)

block_size = quant_module.param_quantizers['weight'].block_size
if block_size is None:
# Per tensor or per channel case
assert _is_reducible(quant_module.weight.shape, quant_module.param_quantizers['weight'].min.shape)
if cls._is_symmetric_quantizer(quant_module.param_quantizers['weight']):
max_tensor = reduce(quant_module.weight.abs(),
quant_module.param_quantizers['weight'].shape, torch.max).values
min_tensor = -max_tensor
else:
max_tensor = reduce(quant_module.weight,
quant_module.param_quantizers['weight'].shape, torch.max).values
min_tensor = reduce(quant_module.weight,
quant_module.param_quantizers['weight'].shape, torch.min).values
else:
# Reshape tensor so each dimension is split into (num_blocks, block_size)
weight = torch_builtins.reshape_tensor_for_blocks(quant_module.weight,
quant_module.param_quantizers['weight'].shape,
block_size)
indices_to_reduce = cls._get_indices_to_reduce(block_size, weight)

# Obtain max_tensor and min_tensor which are equivalent in shape to the original weight, but with block
# values modified to be the block minimum and maximum.
# For example assume the original weight is 1 output channel and 6 input channels, with block size 2:
# Original weight: [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]
# Then, max tensor would be: [[2.0, 2.0, 4.0, 4.0, 6.0, 6.0]]
if cls._is_symmetric_quantizer(quant_module.param_quantizers['weight']):
max_tensor = torch.maximum(weight,
torch.amax(weight.abs(),
indices_to_reduce,
keepdim=True)).reshape(quant_module.weight.shape)
min_tensor = -max_tensor
else:
max_tensor = torch.maximum(weight,
torch.amax(weight,
indices_to_reduce,
keepdim=True)).reshape(quant_module.weight.shape)
min_tensor = torch.minimum(weight,
torch.amin(weight,
indices_to_reduce,
keepdim=True)).reshape(quant_module.weight.shape)

return min_tensor, max_tensor

@classmethod
def _get_candidate(cls, candidate_idx: int, num_candidates: int, min_tensor: torch.Tensor,
max_tensor: torch.Tensor):
"""
Get candidate min and max tensors
"""
cand_max = max_tensor / num_candidates * (candidate_idx + 1)
cand_min = min_tensor / num_candidates * (candidate_idx + 1)
return cand_min, cand_max

@classmethod
def _compute_loss(cls,
quant_module: BaseQuantizationMixin,
x: torch.Tensor,
xq: torch.Tensor,
w: torch.Tensor,
wq: torch.Tensor,
params) -> torch.Tensor:
"""
Compute loss for the given (x, w) and (xq, wq) input/weight pairs. Assumes that block size will be on
input_channel dimension.
"""
# pylint: disable=too-many-locals
# General strategy: split weights and input per block, and run a separate forward pass for each split.
# In the case of per tensor and per channel, the entire input channel is treated as one block.
block_size = cls._get_input_channel_block_size(quant_module)
w_blocks = torch.split(w, block_size, dim=1)
wq_blocks = torch.split(wq, block_size, dim=1)
if isinstance(quant_module, torch.nn.Linear):
x_blocks = torch.split(x, block_size, dim=-1)
xq_blocks = torch.split(xq, block_size, dim=-1)
else:
x_blocks = torch.split(x, block_size, dim=-3)
xq_blocks = torch.split(xq, block_size, dim=-3)

block_losses = []
for idx, x_block in enumerate(x_blocks):
xqwq, xw = cls.compute_outputs(quant_module, x_block, xq_blocks[idx], w_blocks[idx], wq_blocks[idx])
block_losses.append(cls.compute_recon_loss(xqwq, xw, params))
# Stack losses in the input channel dimension
block_losses = torch.stack(block_losses, dim=-1)
return block_losses

@classmethod
def optimize_module(cls,
quant_module: BaseQuantizationMixin,
x: torch.Tensor,
xq: torch.Tensor,
params: SeqMseParams):
"""
Find and freeze optimal parameter encodings candidate for given module.
:param quant_module: Quant module to be optimized
:param x: Inputs to module from FP32 model
:param xq: Inputs to module from QuantSim model
:param params: Sequenial MSE parameters
"""
# pylint: disable=too-many-locals
min_tensor, max_tensor = cls.get_min_and_max_for_candidate_selection(quant_module)

total_loss = []
for i in range(params.num_candidates):
cand_min, cand_max = cls._get_candidate(i, params.num_candidates, min_tensor, max_tensor)
cls.compute_param_encodings(quant_module.param_quantizers['weight'], cand_min, cand_max)
w = quant_module.weight
wq = cls._get_quantized_weight(quant_module)
with torch.no_grad():
for batch_idx in range(params.num_batches):
if batch_idx == 0:
loss = cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
else:
loss += cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
total_loss.append(loss)

best_indices = torch.stack(total_loss).min(0)[1]
block_size = cls._get_input_channel_block_size(quant_module)
# In the input_channels dimension, best_indices is of size num_blocks. We use repeat_interleave to expand
# each blockwise index into block_size number of indices. This makes best_indices input_channels dimension
# become size num_blocks * block_size, and allows for elementwise operation with min_tensor and max_tensor.
if block_size != quant_module.weight.shape[1]:
best_indices = best_indices.repeat_interleave(block_size, dim=-1)

# Unsqueeze best_indices until it matches dim length of max_tensor
while best_indices.dim() < max_tensor.dim():
best_indices = best_indices[..., None]

min_tensor, max_tensor = cls._get_candidate(best_indices, params.num_candidates, min_tensor, max_tensor)

# Compute and freeze parameter encodings using best candidate
cls.compute_param_encodings(quant_module.param_quantizers['weight'], min_tensor, max_tensor)
cls._freeze_quantizer_encoding(quant_module.param_quantizers['weight'])

# Global variables for compatibility
apply_seq_mse = SequentialMse.apply_seq_mse
Expand Down
Loading

0 comments on commit eb6337d

Please sign in to comment.