Fix datatype issue using native torch quantizer on torch1.13 #2309

Merged
13 changes: 10 additions & 3 deletions TrainingExtensions/torch/src/python/aimet_torch/torch_quantizer.py
@@ -41,6 +41,7 @@
 import math
 from typing import Tuple, Union
 import torch
+from packaging import version

 import aimet_common.libpymo as libpymo
 from aimet_common.utils import AimetLogger
@@ -49,6 +50,10 @@

 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)

+torch_quantizer_zero_point_data_type = torch.int64
+if version.parse(torch.__version__) > version.parse('1.10.2'):
+    torch_quantizer_zero_point_data_type = torch.int32
+

 def calc_params_for_native_torch_quantizer(quantizer, ch_axis, device: torch.device) \
         -> Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int], int, int]:
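Why `packaging.version` rather than a plain string check: PyTorch version strings do not compare correctly as text, and shipped wheels often carry local build suffixes. A minimal illustration (the version literals below are examples, not values taken from this PR):

from packaging import version

# Lexicographic comparison gets PyTorch versions wrong:
assert '1.13.0' < '1.9.0'          # True as strings, semantically backwards

# version.parse compares release segments numerically and tolerates
# local build tags such as '+cu117' on real torch wheels:
assert version.parse('1.13.0+cu117') > version.parse('1.10.2')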
@@ -62,6 +67,7 @@ def calc_params_for_native_torch_quantizer(quantizer, ch_axis, device: torch.device)

     numSteps = pow(2, quantizer.bitwidth) - 1
     encodings = quantizer.encoding
+
     if quantizer.use_strict_symmetric:
         error_msg = ('Strict symmetric is not supported by native torch quantizer')
         logger.error(error_msg)
@@ -84,13 +90,13 @@ def calc_params_for_native_torch_quantizer(quantizer, ch_axis, device: torch.device)
     else:
         # Per Channel quantization
         scale = torch.tensor([encoding.delta for encoding in encodings], device=device)
-        zero_point = torch.tensor([int(-encoding.offset) for encoding in encodings], device=device)
+        zero_point = torch.tensor([int(-encoding.offset) for encoding in encodings], device=device, dtype=torch_quantizer_zero_point_data_type)
         if quantizer.use_symmetric_encodings and (all([encoding.min < 0 for encoding in encodings])
                                                   or (not quantizer.use_unsigned_symmetric)):
             # Symmetric quantization
             q_max = math.floor(numSteps / 2)
             q_min = -math.ceil(numSteps / 2)
-            zero_point = torch.zeros_like(zero_point)
+            zero_point = torch.zeros_like(zero_point, dtype=torch_quantizer_zero_point_data_type)
         else:
             # Unsigned symmetric
             q_min, q_max = 0, numSteps
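For context, these parameters feed PyTorch's native per-channel fake-quantize op. Below is a minimal sketch with made-up shapes and scales; bitwidth 8 gives numSteps = 255, so q_max = floor(255/2) = 127 and q_min = -ceil(255/2) = -128. Newer torch releases (including 1.13) no longer accept an int64 zero_point tensor here, which appears to be the datatype issue the PR title refers to.

import torch

# Illustrative per-channel parameters for a 4-output-channel weight;
# values are invented for the example, not taken from the PR.
weight = torch.randn(4, 16)
scale = torch.tensor([0.02, 0.01, 0.03, 0.015])
zero_point = torch.zeros(4, dtype=torch.int32)   # symmetric: all offsets zero
q_min, q_max = -128, 127                         # -ceil(255/2), floor(255/2)

weight_q = torch.fake_quantize_per_channel_affine(
    weight, scale, zero_point, axis=0, quant_min=q_min, quant_max=q_max)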
@@ -114,9 +120,10 @@ def __init__(self, quantizer: Union[StaticGridTensorQuantizer, LearnedGridTensorQuantizer]
         self.enabled = quantizer.enabled
         self.data_type = quantizer.data_type
         self.bitwidth = quantizer.bitwidth
+        self._ch_axis = None
+
         if self.data_type == QuantizationDataType.float and self.bitwidth != 16:
             raise ValueError('Only FP16 quantizers are supported by TorchQuantizer')
-        self._ch_axis = None
         encodings = quantizer.encoding
         # To avoid the case where quantizer.enabled is True but quantizer.encoding is None
         if quantizer.enabled and quantizer.encoding:
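The FP16 branch guarded above amounts to a precision round-trip rather than integer affine quantization, which is why float data types are only meaningful at bitwidth 16. A rough sketch of the underlying operation (this is not the AIMET implementation, just the idea):

import torch

x = torch.randn(8)
# Simulate FP16 quantization by casting through half precision and back.
x_fp16 = x.to(torch.float16).to(torch.float32)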