Add tie observers utility in aimet onnx #3387

Merged
130 changes: 123 additions & 7 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -36,6 +36,7 @@
# =============================================================================
""" Implementation for simulating models running on Quantized hardware """

import contextlib
import tempfile
from dataclasses import dataclass
from pathlib import Path
@@ -58,6 +59,7 @@
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.quantsim import extract_global_quantizer_args, VALID_ENCODING_VERSIONS
from aimet_common.utils import save_json_yaml, AimetLogger
from aimet_common.connected_graph.product import Product
from aimet_onnx import utils
from aimet_onnx.meta.operations import Op
from aimet_onnx.meta.utils import get_op_given_param_name, get_param_shape_using_connected_graph
@@ -83,8 +85,31 @@

allowed_op_type_for_per_channel = ['Conv', 'Gemm', 'MatMul', 'ConvTranspose']

# List of op types whose input and output quantizers are to be tied
op_types_to_tie_qtzrs = ['Concat', 'MaxPool', 'AveragePool', 'Resize']
_tie_qtzrs = False

data_types_to_quantize = [np.float32]


@contextlib.contextmanager
def _apply_constraints(flag: bool):
"""
Apply runtime-specific constraints.
For the op types in ``op_types_to_tie_qtzrs``, the runtime requires the input
and output quantizers to share the same encodings.

NOTE: By default, these constraints are not applied.
"""
global _tie_qtzrs # pylint: disable=global-statement
orig_flag = _tie_qtzrs
try:
_tie_qtzrs = flag
yield
finally:
_tie_qtzrs = orig_flag
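
# Illustrative usage (a sketch, not part of this diff): ``_apply_constraints``
# only flips the module-level ``_tie_qtzrs`` flag, so tying takes effect for any
# QuantizationSimModel constructed inside the ``with`` block and reverts on exit:
#
#     with _apply_constraints(True):
#         sim = QuantizationSimModel(model, dummy_input)  # quantizers tied
#     sim = QuantizationSimModel(model, dummy_input)      # default: no tying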


@dataclass
class EncodingMismatchInfo:
"""
@@ -185,25 +210,27 @@ def __init__(self,
self._path = path if path else tempfile.mkdtemp()
if not os.path.exists(self._path):
os.makedirs(self._path, exist_ok=True)

# Get names of parameters and activations to quantize
self._get_param_names()
self._get_activations_to_quantize(dummy_input)

# Disable bias quantization
self._disable_bias_quantization()

self._add_quantization_nodes()

self.session = QuantizationSimModel.build_session(self.model.model, self.providers,
user_onnx_libs=self._user_onnx_libs, path=self._path)
# Apply configurations based on provided config file.
quantsim_configurator = self._add_configuration_(config_file)

self._hw_version = quantsim_configurator._get_hw_version()
self._supported_kernels = quantsim_configurator.get_supported_kernels()
self._op_to_supported_kernel = quantsim_configurator.get_op_to_supported_kernels()

self.quant_args = extract_global_quantizer_args(quant_scheme, quantsim_configurator)

self._apply_exception_rules()
self._tie_quantizers()

# Build onnxruntime inference session
self.session = QuantizationSimModel.build_session(self.model.model, self.providers,
user_onnx_libs=self._user_onnx_libs, path=self._path)
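
# NOTE (sketch; the onnxruntime call below is assumed, not shown in this diff):
# the session is built only after _tie_quantizers(), because tying rewrites the
# quant_info attributes of QcQuantizeOp nodes in the ONNX graph. Building a
# session from the in-memory model proto amounts to:
#
#     import onnxruntime as ort
#     sess = ort.InferenceSession(self.model.model.SerializeToString(),
#                                 providers=self.providers)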

def get_supported_kernels(self) -> Dict:
"""
@@ -548,7 +575,7 @@ def get_op_quantizers(self, op: Op) -> (List, List, Dict):
if param_name in self.qc_quantize_op_dict:
param_quantizers[param_type] = self.qc_quantize_op_dict[param_name]

return (input_quantizers, output_quantizers, param_quantizers)
return input_quantizers, output_quantizers, param_quantizers
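
# Illustrative call (per the signature above): input and output quantizers come
# back as lists of QcQuantizeOp, parameter quantizers as a dict keyed by param
# type (e.g. 'weight'):
#
#     inp_qtzrs, out_qtzrs, param_qtzrs = sim.get_op_quantizers(op)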

def _apply_exception_rules(self):
"""
@@ -792,6 +819,95 @@ def get_all_quantizers(self) -> Tuple[List, List]:

return param_quantizers, activation_quantizers

def _tie_quantizers(self):
"""
Tie the input and output quantizers for the given op types.
"""
if not _tie_qtzrs:
return

cg = self.connected_graph

def _set_quant_info(dst_qtzr_node_name: str, src_qtzr: QcQuantizeOp):
"""
Set quant_info attribute (pointer to the libquant_info object)

:param dst_qtzr_node_name: destination quantizer node name in graph.
:param src_qtzr: source quantizer.
"""
for node in self.model.graph().node:
if node.op_type == 'QcQuantizeOp' and node.name == dst_qtzr_node_name:
for attr in node.attribute:
if attr.name == "quant_info":
attr.i = libpymo.PtrToInt64(src_qtzr.quant_info)
return

def _set_qtzr(dst_qtzr: QcQuantizeOp, src_qtzr: QcQuantizeOp):
"""
Replace the destination quantizer with the source quantizer, and update the
quant_info attribute (pointer to the libquant_info object) in the graph node.

:param dst_qtzr: destination quantizer.
:param src_qtzr: source quantizer
"""
for name, qtzr in self.qc_quantize_op_dict.items():
if dst_qtzr == qtzr:
self.qc_quantize_op_dict[name] = src_qtzr
dst_qtzr_node_name = 'QcQuantizeOp_' + name
# update quant_info attribute (pointer to the libquant_info object) in the graph node.
_set_quant_info(dst_qtzr_node_name, src_qtzr)
return

def _set_src_qtzr(x: Product, consumer: Op, src_qtzr):
producer = x.producer

if not producer:
# ``x`` is a root input (i.e. has no producer).
# In this case, set the input quantizer of the consumer to ``src_qtzr``
i = consumer.inputs.index(x)
inp_qtzr, _, __ = self.get_op_quantizers(consumer)
if i >= len(inp_qtzr):
return

_set_qtzr(dst_qtzr=inp_qtzr[i], src_qtzr=src_qtzr)
return

_, out_qtzr, __ = self.get_op_quantizers(producer)

if out_qtzr:
# There exists an output quantizer associated with the graph node ``producer``.
# In this case, set the output quantizer of the producer to ``src_qtzr``
outputs = [producer.output]
i = outputs.index(x)
_set_qtzr(dst_qtzr=out_qtzr[i], src_qtzr=src_qtzr)

if not out_qtzr or producer.type in op_outputs_to_ignore:
# 1. There is no output quantizer associated with the graph node ``producer``, or
# 2. op is a math invariant op (reshape, permute, etc.).
# In these cases, propagate the encoding further to the ancestors
for inp in producer.inputs:
_set_src_qtzr(inp, consumer=producer, src_qtzr=src_qtzr)

for op in reversed(cg.ordered_ops):
if op.type not in op_types_to_tie_qtzrs:
continue

_, out_qtzr, __ = self.get_op_quantizers(op)

if not out_qtzr:
msg = 'Encoding propagation is only supported for ops with exactly ' \
'1 output quantizer, but found no output quantizers'
raise RuntimeError(msg)

if len(out_qtzr) != 1:
msg = 'Encoding propagation is only supported for ops with exactly ' \
f'1 output quantizer, but found {len(out_qtzr)} ' \
'output quantizers'
raise RuntimeError(msg)

for inp in op.inputs:
_set_src_qtzr(inp, consumer=op, src_qtzr=out_qtzr[0])
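
# A minimal, self-contained sketch of the walk above on a toy graph
# (hypothetical dict-based structures, not AIMET API): the tied op's output
# quantizer replaces its producers' output quantizers, and math-invariant ops
# are skipped so the encoding propagates further up the graph.
#
#     qtzr = {'conv': 'q_conv', 'concat': 'q_cat'}         # node -> output qtzr
#     producer = {'reshape': 'conv', 'concat': 'reshape'}  # node -> its producer
#     invariant = {'reshape'}                              # math-invariant ops
#
#     def tie(node, src):
#         prod = producer.get(node)
#         if prod is None:
#             return                      # root input: nothing to propagate
#         if prod in qtzr:
#             qtzr[prod] = src            # producer's output now shares ``src``
#         if prod not in qtzr or prod in invariant:
#             tie(prod, src)              # keep walking through invariant ops
#
#     tie('concat', qtzr['concat'])
#     assert qtzr == {'conv': 'q_cat', 'concat': 'q_cat'}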


def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str, strict=True) -> \
List[EncodingMismatchInfo]:
48 changes: 26 additions & 22 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -1634,31 +1634,35 @@ def forward(self, inputs):
return x, y


def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
torch.onnx.export(model.eval(),
dummy_input,
filename,
training=torch.onnx.TrainingMode.PRESERVE,
export_params=True,
opset_version=12,
do_constant_folding=False,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(filename))
def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='temp_model.onnx'):
with tempfile.TemporaryDirectory() as tmp_dir:
save_path = os.path.join(tmp_dir, filename)
torch.onnx.export(model.eval(),
dummy_input,
save_path,
training=torch.onnx.TrainingMode.PRESERVE,
export_params=True,
opset_version=12,
do_constant_folding=False,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(save_path))
return model


def _convert_to_onnx(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
torch.onnx.export(model.eval(),
dummy_input,
filename,
training=torch.onnx.TrainingMode.EVAL,
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(filename))
def _convert_to_onnx(model: torch.nn.Module, dummy_input, filename='temp_model.onnx'):
with tempfile.TemporaryDirectory() as tmp_dir:
save_path = os.path.join(tmp_dir, filename)
torch.onnx.export(model.eval(),
dummy_input,
save_path,
training=torch.onnx.TrainingMode.EVAL,
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(save_path))
return model
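
# Example use of the helpers above (torch API only; the exported file lives in
# a TemporaryDirectory and is fully loaded into memory before the directory is
# removed, so the returned ONNXModel stays valid afterwards):
#
#     net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
#     onnx_model = _convert_to_onnx(net, torch.randn(1, 3, 32, 32))
#     onnx_model_no_fold = _convert_to_onnx_no_fold(net, torch.randn(1, 3, 32, 32))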

