Skip to content

Commit

Permalink
Take list in _set_tensor_quantizer
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Tuttle <[email protected]>
  • Loading branch information
quic-mtuttle committed Oct 9, 2024
1 parent bb03529 commit a636cf2
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self, quant_info: libquant_info.QcQuantizeInfo,
self.rounding_mode = rounding_mode
self._is_encoding_frozen = False
self._tensor_quantizer = None
self._set_tensor_quantizer(self._build_tensor_quantizer())
self._set_tensor_quantizer([self._build_tensor_quantizer()])
self.op_mode = op_mode
self.bitwidth = bitwidth
self.use_symmetric_encodings = use_symmetric_encodings
Expand Down Expand Up @@ -126,9 +126,7 @@ def _create_tensor_quantizers(self, num: int):
tensor_quantizer.isEncodingValid = False
tensor_quantizers.append(tensor_quantizer)

self._tensor_quantizer = tensor_quantizers
self.quant_info.tensorQuantizerRef = [libpymo.PtrToInt64(tensor_quantizer)
for tensor_quantizer in tensor_quantizers]
self._set_tensor_quantizer(tensor_quantizers)

self.reset_encoding_stats()

Expand Down Expand Up @@ -178,14 +176,14 @@ def _build_tensor_quantizer(self):
return libpymo.TensorQuantizer(MAP_QUANT_SCHEME_TO_PYMO[self.quant_scheme],
MAP_ROUND_MODE_TO_PYMO[self.rounding_mode])

def _set_tensor_quantizer(self, tensor_quantizer: libpymo.TensorQuantizer):
def _set_tensor_quantizer(self, tensor_quantizers: List[libpymo.TensorQuantizer]):
"""
Stores tensor_quantizer in self._tensor_quantizer and passes a pointer to the object
to the C++ op's QcQuantInfo object
:param tensor_quantizer: The libpymo.TensorQuantizer object to give to the C++ op
:param tensor_quantizers: The list of libpymo.TensorQuantizer objects to give to the C++ op
"""
self._tensor_quantizer = [tensor_quantizer]
self.quant_info.tensorQuantizerRef = [libpymo.PtrToInt64(tensor_quantizer)]
self._tensor_quantizer = tensor_quantizers
self.quant_info.tensorQuantizerRef = [libpymo.PtrToInt64(tensor_quantizer) for tensor_quantizer in tensor_quantizers]

@property
def enabled(self) -> bool:
Expand Down Expand Up @@ -400,7 +398,7 @@ def set_quant_scheme(self, quant_scheme: QuantScheme):
if self.quant_info.usePerChannelMode:
self.enable_per_channel_quantization()
else:
self._set_tensor_quantizer(self._build_tensor_quantizer())
self._set_tensor_quantizer([self._build_tensor_quantizer()])
self.reset_encoding_stats()

def compute_encodings(self) -> Optional[List[libpymo.TfEncoding]]:
Expand Down

0 comments on commit a636cf2

Please sign in to comment.