diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
index 3fec68551f..82a2923a48 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -36,6 +36,7 @@
 # =============================================================================
 """ Implementation for simulating models running on Quantized hardware """
+import contextlib
 import tempfile
 from dataclasses import dataclass
 from pathlib import Path
@@ -84,8 +85,31 @@
 
 allowed_op_type_for_per_channel = ['Conv', 'Gemm', 'MatMul', 'ConvTranspose']
 
+# List of op types whose input and output quantizers to be tied
+op_types_to_tie_qtzrs = ['Concat', 'MaxPool', 'AveragePool', 'Resize']
+_tie_qtzrs = False
+
 data_types_to_quantize = [np.float32]
 
+
+@contextlib.contextmanager
+def _apply_constraints(flag: bool):
+    """
+    Apply runtime specific constraints.
+    For certain ``op_types_to_tie_qtzrs``, runtime has constraints to have same encodings for
+    input and output quantizers.
+
+    NOTE: Default setting doesn't apply these constraints.
+    """
+    global _tie_qtzrs  # pylint: disable=global-statement
+    orig_flag = _tie_qtzrs
+    try:
+        _tie_qtzrs = flag
+        yield
+    finally:
+        _tie_qtzrs = orig_flag
+
+
 @dataclass
 class EncodingMismatchInfo:
     """
@@ -127,8 +151,7 @@ def __init__(self,
                  use_symmetric_encodings: bool = False, use_cuda: bool = True,
                  device: int = 0, config_file: str = None, default_data_type: QuantizationDataType = QuantizationDataType.int,
-                 simplify_model: bool = True, user_onnx_libs: List[str] = None,
-                 path: str = None, op_types_to_tie: Union[str, Tuple] = None):
+                 simplify_model: bool = True, user_onnx_libs: List[str] = None, path: str = None):
         """
         Constructor
@@ -148,7 +171,6 @@ def __init__(self,
         :param simplify_model: Default True, uses onnx simplifier to simplify model
         :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
         :param path: Directory to save the artifacts.
-        :param op_types_to_tie: Operator types for which to tie input and output quantizers
         """
         self.model = model
         if not isinstance(model, ONNXModel):
@@ -180,7 +202,6 @@ def __init__(self,
         else:
             self._op_domain = "aimet.customop.cpu"
             self.providers = ['CPUExecutionProvider']
-        self._op_types_to_tie = op_types_to_tie
         self._user_onnx_libs = user_onnx_libs
         self.param_names = []
         self.input_quantizers_name = []
@@ -802,10 +823,9 @@ def _tie_quantizers(self):
         """
         Tie the input and output quantizers for given op types.
         """
-        if not self._op_types_to_tie:
+        if not _tie_qtzrs:
             return
-        op_types_to_tie = self._op_types_to_tie
         cg = self.connected_graph
 
         def _set_quant_info(dst_qtzr_node_name: str, src_qtzr: QcQuantizeOp):
@@ -868,11 +888,8 @@ def _set_src_qtzr(x: Product, consumer: Op, src_qtzr):
             for inp in producer.inputs:
                 _set_src_qtzr(inp, consumer=producer, src_qtzr=src_qtzr)
 
-        if isinstance(op_types_to_tie, str):
-            op_types_to_tie = (op_types_to_tie, )
-
         for op in reversed(cg.ordered_ops):
-            if op.type not in op_types_to_tie:
+            if op.type not in op_types_to_tie_qtzrs:
                 continue
 
             _, out_qtzr, __ = self.get_op_quantizers(op)
diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py
index 8b2a58866e..7d90d6efa1 100644
--- a/TrainingExtensions/onnx/test/python/test_quantsim.py
+++ b/TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -52,10 +52,9 @@
 from aimet_common import libquant_info
 from aimet_common.defs import QuantScheme, QuantizationDataType, EncodingType
 from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
-from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights
+from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights, _apply_constraints
 from aimet_onnx.qc_quantize_op import OpMode
 from aimet_onnx.utils import make_dummy_input
-from aimet_onnx import utils
 from models.models_for_tests import SingleResidual
 from models import models_for_tests, test_models
 from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model , multi_output_model, custom_add_model, build_lstm_gru_dummy_model, \
@@ -1289,7 +1288,6 @@ def __init__(self):
                 self.relu1 = torch.nn.ReLU()
                 self.conv2 = torch.nn.Conv2d(3,3,3)
                 self.relu2 = torch.nn.ReLU()
-                self.cat = Concat()
 
             def forward(self, x):
                 x1 = x2 = x
@@ -1297,9 +1295,9 @@ def forward(self, x):
                 x1 = self.relu1(x1)
                 x2 = self.conv2(x2)
                 x2 = self.relu2(x2)
-                return self.cat(x1, x2)
+                return torch.cat([x1, x2])
 
         """
-        When: op_types_to_tie=('Concat', )
+        When: _apply_constraints(True)
         Then: q_out1 and q_out2 are replaced with q_out3 as below
@@ -1311,13 +1309,14 @@ def forward(self, x):
         x = torch.randn(1, 3, 24, 24)
         model = _convert_to_onnx(pt_model, x)
         dummy_input = make_dummy_input(model.model)
-        sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=('Concat',))
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
 
-        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
-        assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
-                                  sim.qc_quantize_op_dict['output'].encodings[0])
-        assert _compare_encodings(sim.qc_quantize_op_dict['/relu2/Relu_output_0'].encodings[0],
-                                  sim.qc_quantize_op_dict['output'].encodings[0])
+            sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+            assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
+                                      sim.qc_quantize_op_dict['output'].encodings[0])
+            assert _compare_encodings(sim.qc_quantize_op_dict['/relu2/Relu_output_0'].encodings[0],
+                                      sim.qc_quantize_op_dict['output'].encodings[0])
 
     def test_math_invariant(self):
         """
@@ -1332,7 +1331,6 @@ def __init__(self):
                 super().__init__()
                 self.conv1 = torch.nn.Conv2d(3, 3, 3, padding=1)
                 self.relu1 = torch.nn.ReLU()
-                self.cat = Concat()
 
             def forward(self, x):
                 x1 = x2 = x
@@ -1340,9 +1338,9 @@ def forward(self, x):
                 x1 = self.relu1(x1)
                 x2 = torch.reshape(x2, (-1, 24, 24, 3))
                 x2 = torch.permute(x2, (0, 3, 1, 2))
-                return self.cat(x1, x2)
+                return torch.cat([x1, x2])
 
         """
-        When: op_types_to_tie=('Concat', )
+        When: _apply_constraints(True)
        Then: q_out1 and q_in2 are replaced with q_out3 as below
@@ -1354,13 +1352,14 @@ def forward(self, x):
         dummy_input = torch.randn(1, 3, 24, 24)
         model = _convert_to_onnx(pt_model, dummy_input)
         dummy_input = make_dummy_input(model.model)
-        sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=('Concat', ))
-        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
+            sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
 
-        assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
-                                  sim.qc_quantize_op_dict['output'].encodings[0])
-        assert _compare_encodings(sim.qc_quantize_op_dict['input'].encodings[0],
-                                  sim.qc_quantize_op_dict['output'].encodings[0])
+            assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
+                                      sim.qc_quantize_op_dict['output'].encodings[0])
+            assert _compare_encodings(sim.qc_quantize_op_dict['input'].encodings[0],
+                                      sim.qc_quantize_op_dict['output'].encodings[0])
 
     def test_concat_tree(self):
         """
@@ -1397,7 +1396,7 @@ def forward(self, x):
         model = _convert_to_onnx(pt_model, dummy_input)
         dummy_input = make_dummy_input(model.model)
         """
-        When: op_types_to_tie=('Concat',)
+        When: _apply_constraints(True)
         Then: All q_out{*} are replaced with q_out3 as below
@@ -1407,13 +1406,14 @@ def forward(self, x):
              +-> q_in1 -> conv1 -> *q_out3* -> concat1 -> *q_out3* -> concat2 -> *q_out3*
             |                                                ^
        x -> +-> q_in2a -> conv2a -> *q_out3* -> concat2 -> *q_out3* -------------^
             +-> q_in2b -> conv2b -> *q_out3* ------^
         """
-        sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=('Concat',))
-        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
+            sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
 
-        for cg_op in sim.connected_graph.ordered_ops:
-            if cg_op.type in ['Conv', 'Concat']:
-                _, out_qtzr, __ = sim.get_op_quantizers(cg_op)
-                assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
+            for cg_op in sim.connected_graph.ordered_ops:
+                if cg_op.type in ['Conv', 'Concat']:
+                    _, out_qtzr, __ = sim.get_op_quantizers(cg_op)
+                    assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
 
     @pytest.mark.parametrize('op_type_under_test', [torch.nn.MaxPool2d, torch.nn.AvgPool2d, torch.nn.Upsample])
     def test_output_parametrized(self, op_type_under_test):
@@ -1430,7 +1430,7 @@ def forward(self, x):
                 x1 = self.conv1(x)
                 return self.op_type_under_test(x1)
         """
-        When: op_types_to_tie=('op_type_under_test',)
+        When: _apply_constraints(True)
         Then: q_out1 will be replaced with q_out2 as below
@@ -1441,19 +1441,11 @@ def forward(self, x):
         x = torch.randn(1, 3, 24, 24)
         model = _convert_to_onnx(pt_model, x)
         dummy_input = make_dummy_input(model.model)
-        if isinstance(pt_model.op_type_under_test, torch.nn.MaxPool2d):
-            op_type = "MaxPool"
-        elif isinstance(pt_model.op_type_under_test, torch.nn.AvgPool2d):
-            op_type = "AveragePool"
-        elif isinstance(pt_model.op_type_under_test, torch.nn.Upsample):
-            op_type = "Resize"
-        else:
-            raise ValueError(f"Unsupported op_type")
-
-        sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=op_type)
-        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
+        with _apply_constraints(True):
+            sim = QuantizationSimModel(model, dummy_input)
+            sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
 
-        for cg_op in sim.connected_graph.ordered_ops:
-            if cg_op.type in ['Conv']:
-                _, out_qtzr, __ = sim.get_op_quantizers(cg_op)
-                assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
+            for cg_op in sim.connected_graph.ordered_ops:
+                if cg_op.type in ['Conv']:
+                    _, out_qtzr, __ = sim.get_op_quantizers(cg_op)
+                    assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
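Usage sketch (illustrative, not part of the patch): with this change the quantizer-tying behavior is opted into through the _apply_constraints context manager instead of a constructor argument, mirroring the updated tests. The snippet below assumes model is an ONNXModel and dummy_input is a dict of numpy arrays prepared as in the tests above; those names are placeholders, not part of the diff.

    from aimet_onnx.quantsim import QuantizationSimModel, _apply_constraints

    # `model` and `dummy_input` are assumed to exist already, e.g. built via
    # _convert_to_onnx(...) and make_dummy_input(model.model) as in the tests.
    with _apply_constraints(True):
        # While the flag is set, the sim ties input/output quantizer encodings for
        # the ops listed in op_types_to_tie_qtzrs (Concat, MaxPool, AveragePool, Resize).
        sim = QuantizationSimModel(model, dummy_input)
        sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
    # On exit the module-level _tie_qtzrs flag is restored to its previous value,
    # so quantsims created outside the context keep the default (untied) behavior.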