diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
index 93ef119083..82a2923a48 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -36,6 +36,7 @@
 # =============================================================================
 """ Implementation for simulating models running on Quantized hardware """
+import contextlib
 import tempfile
 from dataclasses import dataclass
 from pathlib import Path
@@ -58,6 +59,7 @@ from aimet_common.defs import QuantScheme, QuantizationDataType
 from aimet_common.quantsim import extract_global_quantizer_args, VALID_ENCODING_VERSIONS
 from aimet_common.utils import save_json_yaml, AimetLogger
+from aimet_common.connected_graph.product import Product
 from aimet_onnx import utils
 from aimet_onnx.meta.operations import Op
 from aimet_onnx.meta.utils import get_op_given_param_name, get_param_shape_using_connected_graph
@@ -83,8 +85,31 @@
 allowed_op_type_for_per_channel = ['Conv', 'Gemm', 'MatMul', 'ConvTranspose']
 
+# List of op types whose input and output quantizers are to be tied
+op_types_to_tie_qtzrs = ['Concat', 'MaxPool', 'AveragePool', 'Resize']
+_tie_qtzrs = False
+
 data_types_to_quantize = [np.float32]
 
+
+@contextlib.contextmanager
+def _apply_constraints(flag: bool):
+    """
+    Apply runtime-specific constraints.
+    For certain op types (``op_types_to_tie_qtzrs``), the runtime requires the
+    input and output quantizers to share the same encodings.
+
+    NOTE: These constraints are not applied by default.
+    """
+    global _tie_qtzrs  # pylint: disable=global-statement
+    orig_flag = _tie_qtzrs
+    try:
+        _tie_qtzrs = flag
+        yield
+    finally:
+        _tie_qtzrs = orig_flag
+
+
 @dataclass
 class EncodingMismatchInfo:
     """
@@ -185,25 +210,27 @@ def __init__(self,
         self._path = path if path else tempfile.mkdtemp()
         if not os.path.exists(self._path):
             os.makedirs(self._path, exist_ok=True)
+
+        # Get names of parameters and activations to quantize
         self._get_param_names()
         self._get_activations_to_quantize(dummy_input)
 
         # Disable bias quantization
         self._disable_bias_quantization()
-
         self._add_quantization_nodes()
-        self.session = QuantizationSimModel.build_session(self.model.model, self.providers,
-                                                          user_onnx_libs=self._user_onnx_libs, path=self._path)
 
+        # Apply configurations based on the provided config file.
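+        # The configurator also supplies the supported kernels and per-op quantizer settings consumed below.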
         quantsim_configurator = self._add_configuration_(config_file)
-
         self._hw_version = quantsim_configurator._get_hw_version()
         self._supported_kernels = quantsim_configurator.get_supported_kernels()
         self._op_to_supported_kernel = quantsim_configurator.get_op_to_supported_kernels()
-
         self.quant_args = extract_global_quantizer_args(quant_scheme, quantsim_configurator)
-
         self._apply_exception_rules()
+        self._tie_quantizers()
+
+        # Build onnxruntime inference session
+        self.session = QuantizationSimModel.build_session(self.model.model, self.providers,
+                                                          user_onnx_libs=self._user_onnx_libs, path=self._path)
 
     def get_supported_kernels(self) -> Dict:
         """
@@ -548,7 +575,7 @@ def get_op_quantizers(self, op: Op) -> (List, List, Dict):
             if param_name in self.qc_quantize_op_dict:
                 param_quantizers[param_type] = self.qc_quantize_op_dict[param_name]
 
-        return (input_quantizers, output_quantizers, param_quantizers)
+        return input_quantizers, output_quantizers, param_quantizers
 
     def _apply_exception_rules(self):
         """
@@ -792,6 +819,95 @@ def get_all_quantizers(self) -> Tuple[List, List]:
 
         return param_quantizers, activation_quantizers
 
+    def _tie_quantizers(self):
+        """
+        Tie the input and output quantizers for the op types listed in ``op_types_to_tie_qtzrs``.
+        """
+        if not _tie_qtzrs:
+            return
+
+        cg = self.connected_graph
+
+        def _set_quant_info(dst_qtzr_node_name: str, src_qtzr: QcQuantizeOp):
+            """
+            Set the quant_info attribute (pointer to the libquant_info object).
+
+            :param dst_qtzr_node_name: destination quantizer node name in the graph.
+            :param src_qtzr: source quantizer.
+            """
+            for node in self.model.graph().node:
+                if node.op_type == 'QcQuantizeOp' and node.name == dst_qtzr_node_name:
+                    for atr in node.attribute:
+                        if atr.name == "quant_info":
+                            atr.i = libpymo.PtrToInt64(src_qtzr.quant_info)
+                            return
+
+        def _set_qtzr(dst_qtzr: QcQuantizeOp, src_qtzr: QcQuantizeOp):
+            """
+            Replace the dst quantizer with the src quantizer and update the quant_info
+            attribute (pointer to the libquant_info object) in the graph node.
+
+            :param dst_qtzr: destination quantizer.
+            :param src_qtzr: source quantizer.
+            """
+            for name, qtzr in self.qc_quantize_op_dict.items():
+                if dst_qtzr == qtzr:
+                    self.qc_quantize_op_dict[name] = src_qtzr
+                    dst_qtzr_node_name = 'QcQuantizeOp_' + name
+                    # Update the quant_info attribute (pointer to the libquant_info object) in the graph node.
+                    _set_quant_info(dst_qtzr_node_name, src_qtzr)
+                    return
+
+        def _set_src_qtzr(x: Product, consumer: Op, src_qtzr):
+            producer = x.producer
+
+            if not producer:
+                # ``x`` is a root input (i.e. has no producer).
+                # In this case, set the input quantizer of the consumer to ``src_qtzr``
+                i = consumer.inputs.index(x)
+                inp_qtzr, _, __ = self.get_op_quantizers(consumer)
+                if i >= len(inp_qtzr):
+                    return
+
+                _set_qtzr(dst_qtzr=inp_qtzr[i], src_qtzr=src_qtzr)
+                return
+
+            _, out_qtzr, __ = self.get_op_quantizers(producer)
+
+            if out_qtzr:
+                # There exists an output quantizer associated with the graph node ``producer``.
+                # In this case, set the output quantizer of the producer to ``src_qtzr``
+                outputs = [producer.output]
+                i = outputs.index(x)
+                _set_qtzr(dst_qtzr=out_qtzr[i], src_qtzr=src_qtzr)
+
+            if not out_qtzr or producer.type in op_outputs_to_ignore:
+                # 1. There is no output quantizer associated with the graph node ``producer``, or
+                # 2. op is a math invariant op (reshape, permute, etc.).
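+                #    (these op types are listed in ``op_outputs_to_ignore``)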
+                # In these cases, propagate the encoding further to the ancestors
+                for inp in producer.inputs:
+                    _set_src_qtzr(inp, consumer=producer, src_qtzr=src_qtzr)
+
+        for op in reversed(cg.ordered_ops):
+            if op.type not in op_types_to_tie_qtzrs:
+                continue
+
+            _, out_qtzr, __ = self.get_op_quantizers(op)
+
+            if not out_qtzr:
+                msg = 'Encoding propagation is only supported for ops with exactly ' \
+                      '1 output quantizer, but found no output quantizers'
+                raise RuntimeError(msg)
+
+            if len(out_qtzr) != 1:
+                msg = 'Encoding propagation is only supported for ops with exactly ' \
+                      f'1 output quantizer, but found {len(out_qtzr)} ' \
+                      'output quantizers'
+                raise RuntimeError(msg)
+
+            for inp in op.inputs:
+                _set_src_qtzr(inp, consumer=op, src_qtzr=out_qtzr[0])
+
 
 def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str, strict=True) -> \
         List[EncodingMismatchInfo]:
diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
index 922dea396e..ace8d8ca35 100644
--- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py
+++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -1634,31 +1634,35 @@ def forward(self, inputs):
         return x, y
 
 
-def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
-    torch.onnx.export(model.eval(),
-                      dummy_input,
-                      filename,
-                      training=torch.onnx.TrainingMode.PRESERVE,
-                      export_params=True,
-                      opset_version=12,
-                      do_constant_folding=False,
-                      input_names=['input'],
-                      output_names=['output'])
-    model = ONNXModel(load_model(filename))
+def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='temp_model.onnx'):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        save_path = os.path.join(tmp_dir, filename)
+        torch.onnx.export(model.eval(),
+                          dummy_input,
+                          save_path,
+                          training=torch.onnx.TrainingMode.PRESERVE,
+                          export_params=True,
+                          opset_version=12,
+                          do_constant_folding=False,
+                          input_names=['input'],
+                          output_names=['output'])
+        model = ONNXModel(load_model(save_path))
     return model
 
 
-def _convert_to_onnx(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
-    torch.onnx.export(model.eval(),
-                      dummy_input,
-                      filename,
-                      training=torch.onnx.TrainingMode.EVAL,
-                      export_params=True,
-                      opset_version=12,
-                      do_constant_folding=True,
-                      input_names=['input'],
-                      output_names=['output'])
-    model = ONNXModel(load_model(filename))
+def _convert_to_onnx(model: torch.nn.Module, dummy_input, filename='temp_model.onnx'):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        save_path = os.path.join(tmp_dir, filename)
+        torch.onnx.export(model.eval(),
+                          dummy_input,
+                          save_path,
+                          training=torch.onnx.TrainingMode.EVAL,
+                          export_params=True,
+                          opset_version=12,
+                          do_constant_folding=True,
+                          input_names=['input'],
+                          output_names=['output'])
+        model = ONNXModel(load_model(save_path))
     return model
diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py
index 07e9b01be2..7d90d6efa1 100644
--- a/TrainingExtensions/onnx/test/python/test_quantsim.py
+++ b/TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -52,14 +52,20 @@
 from aimet_common import libquant_info
 from aimet_common.defs import QuantScheme, QuantizationDataType, EncodingType
 from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
-from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights
+from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights, _apply_constraints
 from aimet_onnx.qc_quantize_op import OpMode
 from aimet_onnx.utils import make_dummy_input
-from aimet_onnx import utils
 from models.models_for_tests import SingleResidual
 from models import models_for_tests, test_models
 from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model , multi_output_model, custom_add_model, build_lstm_gru_dummy_model, \
-    transposed_conv_model, depthwise_transposed_conv_model, linear_split_into_matmul_add
+    transposed_conv_model, depthwise_transposed_conv_model, linear_split_into_matmul_add, _convert_to_onnx
+
+
+def _compare_encodings(dst, src):
+    return (dst.min == src.min and
+            dst.max == src.max and
+            dst.delta == src.delta and
+            dst.offset == src.offset)
 
 
 class DummyModel(SingleResidual):
@@ -1264,3 +1270,182 @@ def callback(session, dummy_input):
         with open(os.path.join(tempdir, 'gather_model.encodings')) as json_file:
             encoding_data = json.load(json_file)
         assert 'gather_weight' not in encoding_data['activation_encodings'].keys()
+
+class TestEncodingPropagation:
+
+    def test_output(self):
+        """
+        Given: model as below
+
+                  +-> q_in1 -> conv1 -> relu1 ---> q_out1 -------v
+        [input] -+                                            concat -> q_out3 -> [output]
+                  +-> q_in2 -> conv2 -> relu2 ---> q_out2 -------^
+        """
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.relu1 = torch.nn.ReLU()
+                self.conv2 = torch.nn.Conv2d(3, 3, 3)
+                self.relu2 = torch.nn.ReLU()
+
+            def forward(self, x):
+                x1 = x2 = x
+                x1 = self.conv1(x1)
+                x1 = self.relu1(x1)
+                x2 = self.conv2(x2)
+                x2 = self.relu2(x2)
+                return torch.cat([x1, x2])
+        """
+        When: _apply_constraints(True)
+
+        Then: q_out1 and q_out2 are replaced with q_out3 as below
+
+                  +-> q_in1 -> conv1 -> relu1 -> **q_out3** -----v
+        [input] -+                                            concat -> q_out3 -> [output]
+                  +-> q_in2 -> conv2 -> relu2 -> **q_out3** -----^
+        """
+        pt_model = Model().eval()
+        x = torch.randn(1, 3, 24, 24)
+        model = _convert_to_onnx(pt_model, x)
+        dummy_input = make_dummy_input(model.model)
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
+
+        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+        assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
+                                  sim.qc_quantize_op_dict['output'].encodings[0])
+        assert _compare_encodings(sim.qc_quantize_op_dict['/relu2/Relu_output_0'].encodings[0],
+                                  sim.qc_quantize_op_dict['output'].encodings[0])
+
+    def test_math_invariant(self):
+        """
+        Given: model as below
+
+                  +-> q_in1 -> conv1 ---> relu1 -> q_out1 ------v
+        [input] -+                                           concat -> q_out2 -> [output]
+                  +-> q_in2 -> reshape -> permute --------------^
+        """
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3, padding=1)
+                self.relu1 = torch.nn.ReLU()
+
+            def forward(self, x):
+                x1 = x2 = x
+                x1 = self.conv1(x1)
+                x1 = self.relu1(x1)
+                x2 = torch.reshape(x2, (-1, 24, 24, 3))
+                x2 = torch.permute(x2, (0, 3, 1, 2))
+                return torch.cat([x1, x2])
+        """
+        When: _apply_constraints(True)
+
+        Then: q_out1 and q_in2 are replaced with q_out2 as below
+
+                  +-> q_in1 -> conv1 ---> relu1 -> **q_out2** ---v
+        [input] -+                                            concat -> q_out2 -> [output]
+                  +-> **q_out2** -> reshape -> permute ----------^
+        """
+        pt_model = Model().eval()
+        dummy_input = torch.randn(1, 3, 24, 24)
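+        # Export the PyTorch model to ONNX, then derive a fresh dummy input dict for the ONNX graph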
+        model = _convert_to_onnx(pt_model, dummy_input)
+        dummy_input = make_dummy_input(model.model)
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
+        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+
+        assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
+                                  sim.qc_quantize_op_dict['output'].encodings[0])
+        assert _compare_encodings(sim.qc_quantize_op_dict['input'].encodings[0],
+                                  sim.qc_quantize_op_dict['output'].encodings[0])
+
+    def test_concat_tree(self):
+        """
+        Given: model as below
+
+                  +-> q_in1a -> conv1a -> q_out1a -> concat1 -> q_out1c -> reshape --+
+                  +-> q_in1b -> conv1b -> q_out1b ------^                            v
+        [input] --+                                                              concat3 -> q_out3 -> [output]
+                  +-> q_in2a -> conv2a -> q_out2a -> concat2 -> q_out2c -------------^
+                  +-> q_in2b -> conv2b -> q_out2b ------^
+        """
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1a = torch.nn.Conv2d(3, 3, 3)
+                self.conv1b = torch.nn.Conv2d(3, 3, 3)
+                self.conv2a = torch.nn.Conv2d(3, 3, 3)
+                self.conv2b = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                x1a = x1b = x2a = x2b = x
+                x1a = self.conv1a(x1a)
+                x1b = self.conv1b(x1b)
+                x1 = torch.cat([x1a, x1b])
+                x1 = torch.reshape(x1, (-1, 22, 22, 3))
+                x1 = torch.permute(x1, (0, 3, 1, 2))
+                x2a = self.conv2a(x2a)
+                x2b = self.conv2b(x2b)
+                x2 = torch.cat([x2a, x2b])
+                return torch.cat([x1, x2])
+
+        pt_model = Model().eval()
+        dummy_input = torch.randn(1, 3, 24, 24)
+        model = _convert_to_onnx(pt_model, dummy_input)
+        dummy_input = make_dummy_input(model.model)
+        """
+        When: _apply_constraints(True)
+
+        Then: All q_out{*} are replaced with q_out3 as below
+
+                  +-> q_in1a -> conv1a -> *q_out3* -> concat1 -> *q_out3* -> reshape --+
+                  +-> q_in1b -> conv1b -> *q_out3* ------^                             v
+        [input] --+                                                                concat3 -> q_out3 -> [output]
+                  +-> q_in2a -> conv2a -> *q_out3* -> concat2 -> *q_out3* -------------^
+                  +-> q_in2b -> conv2b -> *q_out3* ------^
+        """
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
+        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+
+        for cg_op in sim.connected_graph.ordered_ops:
+            if cg_op.type in ['Conv', 'Concat']:
+                _, out_qtzr, __ = sim.get_op_quantizers(cg_op)
+                assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
+
+    @pytest.mark.parametrize('op_type_under_test', [torch.nn.MaxPool2d, torch.nn.AvgPool2d, torch.nn.Upsample])
+    def test_output_parametrized(self, op_type_under_test):
+        """
+        Given: model as below
+
+        [input] -+-> q_in1 -> conv1 -> q_out1 -> op_type_under_test -> q_out2 -> [output]
+        """
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.op_type_under_test = op_type_under_test(3)
+
+            def forward(self, x):
+                x1 = self.conv1(x)
+                return self.op_type_under_test(x1)
+        """
+        When: _apply_constraints(True)
+
+        Then: q_out1 will be replaced with q_out2 as below
+
+        [input] -+-> q_in1 -> conv1 -> *q_out2* -> op_type_under_test -> q_out2 -> [output]
+        """
+        pt_model = Model().eval()
+        x = torch.randn(1, 3, 24, 24)
+        model = _convert_to_onnx(pt_model, x)
+        dummy_input = make_dummy_input(model.model)
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
+        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+
+        for cg_op in sim.connected_graph.ordered_ops:
+            if cg_op.type in ['Conv']:
+                _, out_qtzr, __ = sim.get_op_quantizers(cg_op)
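+                # conv1's output quantizer should now share its encoding with the model output quantizer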
+                assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])