Skip to content

Commit

Permalink
Support 1.0.0 encoding format in aimet_onnx
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Tuttle <[email protected]>
  • Loading branch information
quic-mtuttle authored Oct 10, 2024
1 parent 0a0ead8 commit 5fa55b3
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 31 deletions.
8 changes: 8 additions & 0 deletions TrainingExtensions/common/src/python/aimet_common/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,11 @@ def __init__(self, func: Callable, func_callback_args=None):
"""
self.func = func
self.args = func_callback_args

class EncodingType(Enum):
""" Encoding type """
PER_TENSOR = 0
PER_CHANNEL = 1
PER_BLOCK = 2
LPBQ = 3
VECTOR = 4
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
# The patching version shall be updated to indicate minor updates to quantization simulation e.g. bug fix etc.
encoding_version = '0.6.1'
ALLOW_EXPERIMENTAL = False
VALID_ENCODING_VERSIONS = {'0.6.1', '1.0.0'}


def gate_min_max(min_val: float, max_val: float) -> Tuple[float, float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from tqdm import tqdm

# Import AIMET specific modules
from aimet_common import quantsim
from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme, QuantizationDataType

Expand Down Expand Up @@ -335,7 +336,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q
:param quant_sim: QunatSim that contains the model and Adaround tensor quantizers
"""
# pylint: disable=protected-access
param_encodings = quant_sim._get_encodings(quant_sim.param_names)
param_encodings = quant_sim._get_encodings(quant_sim.param_names, quantsim.encoding_version)

# export encodings to JSON file
os.makedirs(os.path.abspath(path), exist_ok=True)
Expand Down
36 changes: 34 additions & 2 deletions TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
# =============================================================================
""" Custom QcQuantizeOp to quantize weights and activations using ONNXRuntime """

from typing import Union, List, Optional
from typing import Union, List, Optional, Dict
import aimet_common.libpymo as libpymo
from aimet_common.libpymo import TensorQuantizerOpMode
from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType
from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType, EncodingType
from aimet_common import libquant_info
from aimet_common.utils import deprecated

Expand Down Expand Up @@ -462,6 +462,9 @@ def export_encodings(self, encoding_version: str = "0.6.1"):
if encoding_version == '0.6.1':
return self._export_legacy_encodings()

if encoding_version == "1.0.0":
return self._export_1_0_0_encodings()

raise RuntimeError(f"Unsupported encoding export version: {encoding_version}")

def _export_legacy_encodings(self) -> Union[List, None]:
Expand Down Expand Up @@ -490,3 +493,32 @@ def _export_legacy_encodings(self) -> Union[List, None]:
return encodings

raise RuntimeError(f"Exporting data type {self.data_type} not supported")

def _encoding_type(self):
if not self.quant_info.usePerChannelMode:
return EncodingType.PER_TENSOR
if not self.quant_info.blockSize:
return EncodingType.PER_CHANNEL
return EncodingType.PER_BLOCK

def _export_1_0_0_encodings(self) -> Optional[Dict]:
"""
Exports the quantizer's encodings in the "1.0.0" encoding format
"""
if not self.enabled or not self.is_initialized():
return None

enc_dict = dict(enc_type=self._encoding_type().name,
dtype="INT" if self.data_type == QuantizationDataType.int else "FLOAT",
bw=self.bitwidth,
)

if self.data_type == QuantizationDataType.int:
enc_dict["is_sym"] = self.use_symmetric_encodings
encodings = self.get_encodings()
enc_dict["scale"] = [enc.delta for enc in encodings]
enc_dict["offset"] = [enc.offset for enc in encodings]
if self.quant_info.blockSize > 0:
enc_dict["block_size"] = self.quant_info.blockSize

return enc_dict
30 changes: 20 additions & 10 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@
from packaging import version

# pylint: disable=wrong-import-order
from aimet_common import libpymo
from aimet_common import libpymo, quantsim
from aimet_common import libquant_info
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.quantsim import encoding_version, extract_global_quantizer_args
from aimet_common.quantsim import extract_global_quantizer_args, VALID_ENCODING_VERSIONS
from aimet_common.utils import save_json_yaml, AimetLogger
from aimet_onnx import utils
from aimet_onnx.meta.operations import Op
Expand Down Expand Up @@ -676,25 +676,35 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
qc_op.compute_encodings()
qc_op.op_mode = OpMode.quantizeDequantize

def _get_encodings(self, quantizer_names) -> Dict:
def _get_encodings(self, quantizer_names, enc_version):
encoding_dict = {}
for name in quantizer_names:
encoding = self.qc_quantize_op_dict[name].export_encodings(encoding_version)
encoding = self.qc_quantize_op_dict[name].export_encodings(enc_version)
if encoding is None:
continue
encoding_dict[name] = encoding
return encoding_dict

def _export_encodings(self, encoding_file_path):
if version.parse(enc_version) < version.parse("1.0.0"):
return encoding_dict

for name, encoding in encoding_dict.items():
encoding["name"] = name
return list(encoding_dict.values())

def _export_encodings(self, encoding_file_path, enc_version):
"""
Export encodings to json and yaml file
:param encoding_file_path: path to save the encoding files
"""
param_encodings = self._get_encodings(self.param_names)
activation_encodings = self._get_encodings(self.activation_names)
if enc_version not in VALID_ENCODING_VERSIONS:
raise NotImplementedError(f'Encoding version {enc_version} not in set of valid encoding '
f'versions {VALID_ENCODING_VERSIONS}.')

param_encodings = self._get_encodings(self.param_names, enc_version)
activation_encodings = self._get_encodings(self.activation_names, enc_version)

encodings_dict = {'version': encoding_version,
encodings_dict = {'version': enc_version,
'activation_encodings': activation_encodings,
'param_encodings': param_encodings,
'quantizer_args': self.quant_args}
Expand Down Expand Up @@ -733,7 +743,7 @@ def export(self, path: str, filename_prefix: str):
:param path: dir to save encoding files
:param filename_prefix: filename to save encoding files
"""
self._export_encodings(os.path.join(path, filename_prefix) + '.encodings')
self._export_encodings(os.path.join(path, filename_prefix) + '.encodings', quantsim.encoding_version)
self.remove_quantization_nodes()
if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
# Note: Saving as external data mutates the saved model, removing all initializer data
Expand Down
42 changes: 39 additions & 3 deletions TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import os
import pytest
from aimet_common import libpymo
from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType
from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType, EncodingType
from aimet_onnx.qc_quantize_op import QcQuantizeOp, OpMode, TensorQuantizerParams
from aimet_common import libquant_info
from aimet_common.quantsim import calculate_delta_offset
Expand Down Expand Up @@ -812,7 +812,7 @@ def test_export_per_tensor_int_encodings(self, symmetric, bitwidth, delta, offse
encoding.bw = bitwidth
encoding.offset = offset
encoding.delta = delta
qc_quantize_op.load_encodings([encoding])
qc_quantize_op.update_quantizer_and_load_encodings([encoding], symmetric, False, False, QuantizationDataType.int)
exported_encodings = qc_quantize_op.export_encodings("0.6.1")
assert len(exported_encodings) == 1
assert exported_encodings[0]["scale"] == delta
Expand All @@ -821,11 +821,26 @@ def test_export_per_tensor_int_encodings(self, symmetric, bitwidth, delta, offse
assert exported_encodings[0]["dtype"] == "int"
assert exported_encodings[0]["is_symmetric"] == str(symmetric)

exported_encodings = qc_quantize_op.export_encodings("1.0.0")
assert isinstance(exported_encodings, dict)
assert exported_encodings.keys() == {"enc_type", "dtype", "bw", "is_sym", "scale", "offset"}
assert exported_encodings["dtype"] == "INT"
assert exported_encodings["enc_type"] == EncodingType.PER_TENSOR.name
assert exported_encodings["bw"] == bitwidth
assert exported_encodings["is_sym"] == symmetric
assert isinstance(exported_encodings["scale"], list)
assert isinstance(exported_encodings["offset"], list)
assert len(exported_encodings["scale"]) == 1
assert len(exported_encodings["offset"]) == 1
assert exported_encodings["scale"][0] == delta
assert exported_encodings["offset"][0] == offset

@pytest.mark.parametrize("symmetric, bitwidth, delta, offset", [(True, 8, 0.1, -128),])
def test_export_per_channel_int_encodings(self, symmetric, bitwidth, delta, offset):
channel_axis = 0
block_axis = 1
tensor_shape = [5, 8]
params = TensorQuantizerParams(tensor_shape, channel_axis)
params = TensorQuantizerParams(tensor_shape, channel_axis, block_axis)

quant_info = libquant_info.QcQuantizeInfo()
quant_info.usePerChannelMode = False
Expand All @@ -844,6 +859,22 @@ def test_export_per_channel_int_encodings(self, symmetric, bitwidth, delta, offs
exported_encodings = qc_quantize_op.export_encodings("0.6.1")
assert len(exported_encodings) == tensor_shape[channel_axis]

exported_encodings = qc_quantize_op.export_encodings("1.0.0")
assert exported_encodings.keys() == {"enc_type", "dtype", "bw", "is_sym", "scale", "offset"}
assert exported_encodings["enc_type"] == EncodingType.PER_CHANNEL.name
assert len(exported_encodings["scale"]) == tensor_shape[channel_axis]
assert len(exported_encodings["offset"]) == tensor_shape[channel_axis]

block_size = 4
qc_quantize_op._enable_blockwise_quantization(block_size)
encodings = [libpymo.TfEncoding() for _ in range(tensor_shape[channel_axis] * 2)]
qc_quantize_op.load_encodings(encodings)
exported_encodings = qc_quantize_op.export_encodings("1.0.0")
assert exported_encodings.keys() == {"enc_type", "dtype", "bw", "is_sym", "scale", "offset", "block_size"}
assert exported_encodings["enc_type"] == EncodingType.PER_BLOCK.name
assert len(exported_encodings["scale"]) == tensor_shape[channel_axis] * 2
assert exported_encodings["block_size"] == block_size

def test_export_float_encodings(self):
quant_info = libquant_info.QcQuantizeInfo()
qc_quantize_op = QcQuantizeOp(quant_info, bitwidth=16, op_mode=OpMode.quantizeDequantize)
Expand All @@ -853,6 +884,11 @@ def test_export_float_encodings(self):
assert encodings[0]["dtype"] == "float"
assert encodings[0]["bitwidth"] == 16

exported_encodings = qc_quantize_op.export_encodings("1.0.0")
assert exported_encodings.keys() == {"enc_type", "dtype", "bw"}
assert exported_encodings["dtype"] == "FLOAT"
assert exported_encodings["bw"] == 16

def test_load_float_encodings(self):
quant_info = libquant_info.QcQuantizeInfo()
qc_quantize_op = QcQuantizeOp(quant_info, bitwidth=16, op_mode=OpMode.quantizeDequantize)
Expand Down
50 changes: 49 additions & 1 deletion TrainingExtensions/onnx/test/python/test_quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
# @@-COPYRIGHT-END-@@
# =============================================================================

import contextlib
import itertools
import json
import os
import tempfile
Expand All @@ -46,8 +48,9 @@
import onnxruntime as ort
import pytest

from aimet_common import quantsim
from aimet_common import libquant_info
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.defs import QuantScheme, QuantizationDataType, EncodingType
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights
from aimet_onnx.qc_quantize_op import OpMode
Expand Down Expand Up @@ -103,6 +106,15 @@ def forward(self, inputs):

return x

@contextlib.contextmanager
def set_encoding_version(version):
old_version = quantsim.encoding_version
quantsim.encoding_version = version

yield

quantsim.encoding_version = old_version

class TestQuantSim:
"""Tests for QuantizationSimModel"""
def test_insert_quantize_op_nodes(self):
Expand Down Expand Up @@ -225,6 +237,42 @@ def dummy_callback(session, args):
param_encodings_keys = list(encoding_data["param_encodings"][param][0].keys())
assert param_encodings_keys == ['bitwidth', 'dtype', 'is_symmetric', 'max', 'min', 'offset', 'scale']

def test_export_model_1_0_0(self):
"""Test to export encodings and model in 1.0.0 format"""
model = build_dummy_model()
with tempfile.TemporaryDirectory() as tempdir:
sim = QuantizationSimModel(model, path=tempdir, config_file=get_path_for_per_channel_config())

def dummy_callback(session, _):
session.run(None, make_dummy_input(model))

sim.compute_encodings(dummy_callback, None)
with set_encoding_version("1.0.0"):
sim.export(tempdir, 'quant_sim_model')

with open(os.path.join(tempdir, 'quant_sim_model.encodings'), 'rb') as json_file:
encoding_data = json.load(json_file)

assert encoding_data["version"] == "1.0.0"
assert isinstance(encoding_data["activation_encodings"], list)
assert isinstance(encoding_data["param_encodings"], list)

activation_keys = {enc["name"] for enc in encoding_data["activation_encodings"]}
param_keys = {enc["name"] for enc in encoding_data["param_encodings"]}
assert activation_keys == {'4', '5', 'input', 'output'}
assert param_keys == {'conv_w', 'fc_w'}

for enc in itertools.chain(encoding_data["param_encodings"], encoding_data["activation_encodings"]):
assert isinstance(enc, dict)
assert enc.keys() == {"name", "enc_type", "dtype", "bw", "is_sym", "scale", "offset"}
assert isinstance(enc["scale"], list)
assert enc["dtype"] == "INT"
# Gemm layers do not use per-channel in the default_per_channel_config
if enc["name"] == "conv_w":
assert enc["enc_type"] == EncodingType.PER_CHANNEL.name
else:
assert enc["enc_type"] == EncodingType.PER_TENSOR.name

def test_lstm_gru(self):
"""Test for LSTM and GRU dummy model"""
model = build_lstm_gru_dummy_model()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,17 @@
# =============================================================================
""" Export utilities for QuantizationSimModel """

from enum import Enum
import json
import os
from typing import Dict, List, Tuple

from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantizationDataType
from aimet_common.defs import QuantizationDataType, EncodingType
from aimet_torch.utils import is_vector_encoding


logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

VALID_ENCODING_VERSIONS = {'0.6.1', '1.0.0'}

class EncodingType(Enum):
""" Encoding type """
PER_TENSOR = 0
PER_CHANNEL = 1
PER_BLOCK = 2
LPBQ = 3
VECTOR = 4

def _export_to_1_0_0(path: str,
filename_prefix: str,
tensor_to_activation_encodings: Dict[str, List],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from aimet_common.connected_graph.connectedgraph_utils import CG_SPLIT
from aimet_common.utils import AimetLogger, save_json_yaml, log_with_error_and_assert_if_false
from aimet_common.defs import QuantScheme, QuantizationDataType, SupportedKernelsAction, QuantDtypeBwInfo
from aimet_common.quantsim import validate_quantsim_inputs, extract_global_quantizer_args
from aimet_common.quantsim import validate_quantsim_inputs, extract_global_quantizer_args, VALID_ENCODING_VERSIONS
from aimet_common.quant_utils import get_conv_accum_bounds

from aimet_torch.v1.nn.modules.custom import MatMul
Expand All @@ -76,7 +76,7 @@
from aimet_torch.meta.connectedgraph import ConnectedGraph, Op
from aimet_torch.qc_quantize_recurrent import QcQuantizeRecurrent
from aimet_torch.quantsim_config.builder import LazyQuantizeWrapper
from aimet_torch.experimental.v2.quantsim.export_utils import VALID_ENCODING_VERSIONS, _export_to_1_0_0
from aimet_torch.experimental.v2.quantsim.export_utils import _export_to_1_0_0


logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
Expand Down

0 comments on commit 5fa55b3

Please sign in to comment.