Commit

Expose temp API to apply constraints

Signed-off-by: Hitarth Mehta <[email protected]>
quic-hitameht committed Oct 10, 2024
1 parent fe43230 commit 138f8c8
Showing 2 changed files with 62 additions and 53 deletions.
37 changes: 27 additions & 10 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -36,6 +36,7 @@
# =============================================================================
""" Implementation for simulating models running on Quantized hardware """

import contextlib
import tempfile
from dataclasses import dataclass
from pathlib import Path
@@ -84,8 +85,31 @@

allowed_op_type_for_per_channel = ['Conv', 'Gemm', 'MatMul', 'ConvTranspose']

# List of op types whose input and output quantizers are to be tied
op_types_to_tie_qtzrs = ['Concat', 'MaxPool', 'AveragePool', 'Resize']
_tie_qtzrs = False

data_types_to_quantize = [np.float32]


@contextlib.contextmanager
def _apply_constraints(flag: bool):
"""
Apply runtime specific constraints.
For certain ``op_types_to_tie_qtzrs``, runtime has constraints to have same encodings for
input and output quantizers.
NOTE: Default setting doesn't apply these constraints.
"""
global _tie_qtzrs # pylint: disable=global-statement
orig_flag = _tie_qtzrs
try:
_tie_qtzrs = flag
yield
finally:
_tie_qtzrs = orig_flag


@dataclass
class EncodingMismatchInfo:
"""
@@ -127,8 +151,7 @@ def __init__(self,
use_symmetric_encodings: bool = False, use_cuda: bool = True,
device: int = 0, config_file: str = None,
default_data_type: QuantizationDataType = QuantizationDataType.int,
simplify_model: bool = True, user_onnx_libs: List[str] = None,
path: str = None, op_types_to_tie: Union[str, Tuple] = None):
simplify_model: bool = True, user_onnx_libs: List[str] = None, path: str = None):
"""
Constructor
@@ -148,7 +171,6 @@ def __init__(self,
:param simplify_model: Default True, uses onnx simplifier to simplify model
:param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
:param path: Directory to save the artifacts.
:param op_types_to_tie: Operator types for which to tie input and output quantizers
"""
self.model = model
if not isinstance(model, ONNXModel):
@@ -180,7 +202,6 @@ def __init__(self,
else:
self._op_domain = "aimet.customop.cpu"
self.providers = ['CPUExecutionProvider']
self._op_types_to_tie = op_types_to_tie
self._user_onnx_libs = user_onnx_libs
self.param_names = []
self.input_quantizers_name = []
@@ -802,10 +823,9 @@ def _tie_quantizers(self):
"""
Tie the input and output quantizers for given op types.
"""
if not self._op_types_to_tie:
if not _tie_qtzrs:
return

op_types_to_tie = self._op_types_to_tie
cg = self.connected_graph

def _set_quant_info(dst_qtzr_node_name: str, src_qtzr: QcQuantizeOp):
@@ -868,11 +888,8 @@ def _set_src_qtzr(x: Product, consumer: Op, src_qtzr):
for inp in producer.inputs:
_set_src_qtzr(inp, consumer=producer, src_qtzr=src_qtzr)

if isinstance(op_types_to_tie, str):
op_types_to_tie = (op_types_to_tie, )

for op in reversed(cg.ordered_ops):
if op.type not in op_types_to_tie:
if op.type not in op_types_to_tie_qtzrs:
continue

_, out_qtzr, __ = self.get_op_quantizers(op)
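Taken together, the quantsim.py changes above replace the old per-instance op_types_to_tie argument with a module-level flag that the _apply_constraints context manager toggles. A minimal usage sketch follows; it is illustrative only: the "model.onnx" path is a placeholder, and the calibration callback simply reruns the dummy input, mirroring the tests below.

import onnx
from aimet_onnx.quantsim import QuantizationSimModel, _apply_constraints
from aimet_onnx.utils import make_dummy_input

model = onnx.load("model.onnx")          # placeholder path, not part of this change
dummy_input = make_dummy_input(model)    # dict mapping input names to numpy arrays

with _apply_constraints(True):
    # While the flag is set, input and output quantizers of Concat, MaxPool,
    # AveragePool and Resize ops are tied so they share one encoding.
    sim = QuantizationSimModel(model, dummy_input)
    sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
# Leaving the context manager restores the default behaviour (constraints not applied).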
78 changes: 35 additions & 43 deletions TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -52,10 +52,9 @@
from aimet_common import libquant_info
from aimet_common.defs import QuantScheme, QuantizationDataType, EncodingType
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights
from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights, _apply_constraints
from aimet_onnx.qc_quantize_op import OpMode
from aimet_onnx.utils import make_dummy_input
from aimet_onnx import utils
from models.models_for_tests import SingleResidual
from models import models_for_tests, test_models
from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model , multi_output_model, custom_add_model, build_lstm_gru_dummy_model, \
@@ -1289,17 +1288,16 @@ def __init__(self):
self.relu1 = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(3,3,3)
self.relu2 = torch.nn.ReLU()
self.cat = Concat()

def forward(self, x):
x1 = x2 = x
x1 = self.conv1(x1)
x1 = self.relu1(x1)
x2 = self.conv2(x2)
x2 = self.relu2(x2)
return self.cat(x1, x2)
return torch.cat([x1, x2])
"""
When: op_types_to_tie=('Concat', )
When: _apply_constraints(True)
Then: q_out1 and q_out2 are replaced with q_out3 as below
@@ -1311,13 +1309,14 @@ def forward(self, x):
x = torch.randn(1, 3, 24, 24)
model = _convert_to_onnx(pt_model, x)
dummy_input = make_dummy_input(model.model)
sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=('Concat',))
with _apply_constraints(True):
sim = QuantizationSimModel(model, dummy_input)

sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])
assert _compare_encodings(sim.qc_quantize_op_dict['/relu2/Relu_output_0'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])
sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])
assert _compare_encodings(sim.qc_quantize_op_dict['/relu2/Relu_output_0'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])

def test_math_invariant(self):
"""
@@ -1332,17 +1331,16 @@ def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3, padding=1)
self.relu1 = torch.nn.ReLU()
self.cat = Concat()

def forward(self, x):
x1 = x2 = x
x1 = self.conv1(x1)
x1 = self.relu1(x1)
x2 = torch.reshape(x2, (-1, 24, 24, 3))
x2 = torch.permute(x2, (0, 3, 1, 2))
return self.cat(x1, x2)
return torch.cat([x1, x2])
"""
When: op_types_to_tie=('Concat', )
When: _apply_constraints(True)
Then: q_out1 and q_in2 are replaced with q_out3 as below
@@ -1354,13 +1352,14 @@ def forward(self, x):
dummy_input = torch.randn(1, 3, 24, 24)
model = _convert_to_onnx(pt_model, dummy_input)
dummy_input = make_dummy_input(model.model)
sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=('Concat', ))
sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
with _apply_constraints(True):
sim = QuantizationSimModel(model, dummy_input)
sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)

assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])
assert _compare_encodings(sim.qc_quantize_op_dict['input'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])
assert _compare_encodings(sim.qc_quantize_op_dict['/relu1/Relu_output_0'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])
assert _compare_encodings(sim.qc_quantize_op_dict['input'].encodings[0],
sim.qc_quantize_op_dict['output'].encodings[0])

def test_concat_tree(self):
"""
@@ -1397,7 +1396,7 @@ def forward(self, x):
model = _convert_to_onnx(pt_model, dummy_input)
dummy_input = make_dummy_input(model.model)
"""
When: op_types_to_tie=('Concat',)
When: _apply_constraints(True)
Then: All q_out{*} are replaced with q_out3 as below
@@ -1407,13 +1406,14 @@ def forward(self, x):
+-> q_in2a -> conv2a -> *q_out3* -> concat2 -> *q_out3* -------------^
+-> q_in2b -> conv2b -> *q_out3* ------^
"""
sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=('Concat',))
sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
with _apply_constraints(True):
sim = QuantizationSimModel(model, dummy_input)
sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)

for cg_op in sim.connected_graph.ordered_ops:
if cg_op.type in ['Conv', 'Concat']:
_, out_qtzr, __ = sim.get_op_quantizers(cg_op)
assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
for cg_op in sim.connected_graph.ordered_ops:
if cg_op.type in ['Conv', 'Concat']:
_, out_qtzr, __ = sim.get_op_quantizers(cg_op)
assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])

@pytest.mark.parametrize('op_type_under_test', [torch.nn.MaxPool2d, torch.nn.AvgPool2d, torch.nn.Upsample])
def test_output_parametrized(self, op_type_under_test):
Expand All @@ -1430,7 +1430,7 @@ def forward(self, x):
x1 = self.conv1(x)
return self.op_type_under_test(x1)
"""
When: op_types_to_tie=('op_type_under_test',)
When: _apply_constraints(True)
Then: q_out1 will be replaced with q_out2 as below
@@ -1441,19 +1441,11 @@ def forward(self, x):
x = torch.randn(1, 3, 24, 24)
model = _convert_to_onnx(pt_model, x)
dummy_input = make_dummy_input(model.model)
if isinstance(pt_model.op_type_under_test, torch.nn.MaxPool2d):
op_type = "MaxPool"
elif isinstance(pt_model.op_type_under_test, torch.nn.AvgPool2d):
op_type = "AveragePool"
elif isinstance(pt_model.op_type_under_test, torch.nn.Upsample):
op_type = "Resize"
else:
raise ValueError(f"Unsupported op_type")

sim = QuantizationSimModel(model, dummy_input, op_types_to_tie=op_type)
sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)
with _apply_constraints(True):
sim = QuantizationSimModel(model, dummy_input)
sim.compute_encodings(lambda session, _: session.run(None, dummy_input), None)

for cg_op in sim.connected_graph.ordered_ops:
if cg_op.type in ['Conv']:
_, out_qtzr, __ = sim.get_op_quantizers(cg_op)
assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
for cg_op in sim.connected_graph.ordered_ops:
if cg_op.type in ['Conv']:
_, out_qtzr, __ = sim.get_op_quantizers(cg_op)
assert _compare_encodings(out_qtzr[0].encodings[0], sim.qc_quantize_op_dict['output'].encodings[0])
