diff --git a/TrainingExtensions/common/src/python/aimet_common/amp/convert_ops_reduction.py b/TrainingExtensions/common/src/python/aimet_common/amp/convert_ops_reduction.py
index 95ccfff8548..567ab8ad0e1 100644
--- a/TrainingExtensions/common/src/python/aimet_common/amp/convert_ops_reduction.py
+++ b/TrainingExtensions/common/src/python/aimet_common/amp/convert_ops_reduction.py
@@ -148,8 +148,8 @@ def rename_nodes(G: nx.DiGraph, module_name_to_module_dict: Dict, dotted_name2op
     for node in G_copy.nodes:
         if node not in ["input_ops", "output_ops"]:
             if ("input_ops", node) in G_copy.edges: # quantizer groups with an input quantizer
-                try:
-                    module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                if module_name is not None:
                     mapping = {node: module_name + "_output"}
                     G = nx.relabel_nodes(G, mapping)
@@ -173,14 +173,14 @@
                     G.nodes[new_node_name]["tensor_dims"] = input_shape
                     G.nodes[new_node_name]["tensor_size"] = input_size
-                except:
+                else:
                     _logger.info("did not change node name: %s", node)
             else:
-                try:
-                    module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                if module_name is not None:
                     mapping = {node: module_name + "_output"}
                     G = nx.relabel_nodes(G, mapping)
-                except:
+                else:
                     _logger.info("did not change node name: %s", node)
 
     G.remove_node("input_ops")
diff --git a/TrainingExtensions/common/src/python/aimet_common/amp/quantizer_groups.py b/TrainingExtensions/common/src/python/aimet_common/amp/quantizer_groups.py
index 9da9e732b62..d1e9dc842a2 100644
--- a/TrainingExtensions/common/src/python/aimet_common/amp/quantizer_groups.py
+++ b/TrainingExtensions/common/src/python/aimet_common/amp/quantizer_groups.py
@@ -199,7 +199,6 @@ def get_supported_candidates_for_quantizers(quantizers: List,
             # Store candidates for quantizer
             store_candidates_for_quantizer(supported_kernels, op, amp_candidates_set, act_bw_set, act_and_param_set,
                                            act_only_set, null_intersection_ops)
-            break
 
     # Default candidate selected if op not found in supported kernels
     if not ops_found:
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py b/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py
index 247c3ca857d..253bbea58b2 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py
@@ -37,7 +37,7 @@
 """ Find quantizer groups in a model """
 import itertools
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 from collections import defaultdict
 from dataclasses import dataclass, field
 import torch
@@ -66,6 +66,7 @@ class QuantizerGroup(QuantizerGroupBase):
     input_quantizers: Tuple[str, ...] = field(default_factory=tuple)
     output_quantizers: Tuple[str, ...] = field(default_factory=tuple)
     parameter_quantizers: Tuple[str, ...] = field(default_factory=tuple)
+    supported_kernel_ops: Tuple[str, ...] = field(default_factory=tuple)
 
     def get_candidate(self, name_to_quantizer_dict: Dict) -> CANDIDATE_WITH_DTYPE:
         """
@@ -167,7 +168,8 @@ def get_input_quantizer_modules(self):
         return tuple(sorted(result))
 
 
-def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tuple[str, torch.nn.Module]:
+def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> \
+        Tuple[Optional[str], Optional[torch.nn.Module]]:
     """
     Finds quantization (wrapping) module corresponding to the wrapper module's dotted name
     :param op_name: Dotted name of op as represented in connected graph
@@ -179,7 +181,7 @@ def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tu
         if module_name in module_name_to_quantizer_dict:
             return module_name, module_name_to_quantizer_dict[module_name]
     # Else it is a functional op
-    raise KeyError
+    return None, None
 
 
 def get_module_name_to_module_dict(sim: QuantizationSimModel) -> Dict:
@@ -320,11 +322,8 @@ def get_input_and_param_quantizers(
     """
     input_quantizers = []
    parameter_quantizers = []
-    try:
-        module_name, module = find_wrapper_module(child, module_name_to_module_dict)
-    except KeyError:
-        pass
-    else:
+    module_name, module = find_wrapper_module(child, module_name_to_module_dict)
+    if module_name is not None:
         for idx, input_quantizer in enumerate(module.input_quantizers):
             if input_quantizer.enabled:
                 input_quantizers.append(module_name + '_input_quantizer_idx_' + str(idx))
@@ -358,9 +357,14 @@ def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[Quantize
             # Add one quantizer group for each input and it's weight param
             input_quantizers, parameter_quantizers = get_input_and_param_quantizers(child, module_name_to_module_dict)
             if input_quantizers or parameter_quantizers:
+                child_module_name, _ = find_wrapper_module(child, module_name_to_module_dict)
+                supported_kernel_ops = []
+                if child_module_name is not None:
+                    supported_kernel_ops.append(child_module_name)
                 quantizer_group = QuantizerGroup(
                     input_quantizers=input_quantizers,
-                    parameter_quantizers=parameter_quantizers
+                    parameter_quantizers=parameter_quantizers,
+                    supported_kernel_ops=tuple(supported_kernel_ops)
                 )
                 quantizer_groups.append(quantizer_group)
                 logger.debug('\n Quantizer Group Added: %s', quantizer_group)
@@ -375,26 +379,28 @@ def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[Quantize
         if not isinstance(parents, tuple):
             parents = [parents]
         for parent in parents:
-            try:
-                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
-            except KeyError:
-                continue
+            module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
             if module is not None:
                 for output_quantizer in module.output_quantizers:
                     if output_quantizer.enabled:
                         output_quantizers += (module_name,)
 
+        supported_kernel_ops = []
         for child in children:
             input_q, param_q = get_input_and_param_quantizers(child, module_name_to_module_dict)
             input_quantizers += input_q
             parameter_quantizers += param_q
+            child_module_name, _ = find_wrapper_module(child, module_name_to_module_dict)
+            if child_module_name is not None:
+                supported_kernel_ops.append(child_module_name)
 
         # Don't add quantizer group if it is empty
         if input_quantizers or output_quantizers or parameter_quantizers:
             quantizer_group = QuantizerGroup(
                 input_quantizers=input_quantizers,
                 output_quantizers=output_quantizers,
-                parameter_quantizers=parameter_quantizers
+                parameter_quantizers=parameter_quantizers,
+                supported_kernel_ops=tuple(supported_kernel_ops)
             )
             quantizer_groups.append(quantizer_group)
             logger.debug('\n Quantizer Group added: %s', quantizer_group)
@@ -402,15 +408,15 @@ def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[Quantize
     if 'output_ops' in parent_child_op_groups:
         for parent in parent_child_op_groups['output_ops']:
             # Add one quantizer group for each input and it's weight param
-            try:
-                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
-            except KeyError:
-                continue
+            module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
            if module is not None:
                 for output_quantizer in module.output_quantizers:
                     if output_quantizer.enabled:
+                        # Using empty supported kernel ops so that model output quantizers are able to consider all
+                        # default candidates
                         quantizer_group = QuantizerGroup(
                             output_quantizers=(module_name,),
+                            supported_kernel_ops=tuple()
                         )
                         quantizer_groups.append(quantizer_group)
                         logger.debug('\n Quantizer Group added: %s', quantizer_group)
@@ -437,6 +443,7 @@ def find_supported_candidates(quantizer_groups: List[QuantizerGroup],
     quantizers_with_supported_candidates = defaultdict(list)
 
     # pylint: disable=too-many-nested-blocks
+    # pylint: disable=protected-access
     for quantizer_group in quantizer_groups:
         quantizers = sorted(set(itertools.chain(quantizer_group.get_input_quantizer_modules(),
                                                 quantizer_group.output_quantizers,
@@ -444,35 +451,34 @@ def find_supported_candidates(quantizer_groups: List[QuantizerGroup],
                                                 quantizer_group.parameter_quantizers)))
         # quantizers are now unique ops present in the given quantizer_group
         onnx_ops = defaultdict(list)
-        for quantizer in quantizers:
-            if quantizer not in module_name_to_module_dict:
-                raise RuntimeError('module_name_to_module_dict does not contain an entry for the quantizer:',
-                                   quantizer)
-
-            # pylint: disable=protected-access
-            module = module_name_to_module_dict[quantizer]._module_to_wrap
-
+        supported_kernel_types = set()
+        for supported_kernel_op in quantizer_group.supported_kernel_ops:
+            module = module_name_to_module_dict[supported_kernel_op]._module_to_wrap
             try:
                 backend_type = aimet_op_to_backend_op_name_map[module.__class__]
             except KeyError:
                 backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)
 
             if backend_type in supported_kernels:
-                onnx_ops[quantizer] = [backend_type]
+                supported_kernel_types.add(backend_type)
             else:
                 onnx_types = onnx_utils.map_torch_types_to_onnx.get(
-                    type(module_name_to_module_dict[quantizer]._module_to_wrap), [])
+                    type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap), [])
                 if not onnx_types:
                     logger.warning("No mapping found for %s in the torch to onnx op type mapping dictionary.",
-                                   str(type(module_name_to_module_dict[quantizer]._module_to_wrap)))
+                                   str(type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap)))
+
+                supported_kernel_types.update(onnx_types)
 
-                onnx_ops[quantizer] = onnx_types
                 for onnx_type in onnx_types:
                     if onnx_type not in supported_kernels.keys():
                         if module in supported_kernels:
                             supported_kernels[onnx_type] = supported_kernels[module]
 
+        for quantizer in quantizers:
+            onnx_ops[quantizer] = list(supported_kernel_types)
+
         supported_kernels_for_quantizers = get_supported_candidates_for_quantizers(quantizers,
                                                                                    onnx_ops,
                                                                                    supported_kernels,
diff --git a/TrainingExtensions/torch/test/python/test_mixed_precision.py b/TrainingExtensions/torch/test/python/test_mixed_precision.py
index e4cb5df7e24..ace2f3dce3d 100644
--- a/TrainingExtensions/torch/test/python/test_mixed_precision.py
+++ b/TrainingExtensions/torch/test/python/test_mixed_precision.py
@@ -1111,19 +1111,17 @@ def test_supported_candidates_2(
         # default_supported_kernels and conv_supported_kernels are the configurations added in the json file above.
         default_supported_kernels = [((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
                                      ((16, QuantizationDataType.float), (16, QuantizationDataType.float)),
-                                     ((8, QuantizationDataType.float), (16, QuantizationDataType.float))]
+                                     ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
 
         conv_supported_kernels = [((16, QuantizationDataType.float), (16, QuantizationDataType.float)),
                                   ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
 
         for quantizer_group, quantizer_candidates in algo._supported_candidates_per_quantizer_group.items():
-            quantizers = sorted(itertools.chain(quantizer_group.get_input_quantizer_modules(),
-                                                quantizer_group.output_quantizers,
-                                                quantizer_group.parameter_quantizers))
+            supported_kernel_ops = quantizer_group.supported_kernel_ops
             onnx_types = []
-            for q in quantizers:
+            for op in supported_kernel_ops:
                 onnx_types.append(
-                    onnx_utils.map_torch_types_to_onnx.get(type(algo._module_name_dict[q]._module_to_wrap)))
+                    onnx_utils.map_torch_types_to_onnx.get(type(algo._module_name_dict[op]._module_to_wrap)))
 
             # verify to make sure the candidates returned is always part of amp_candidates and they are part of
             # "Defaults" or "Conv" appropriately
diff --git a/TrainingExtensions/torch/test/python/test_quantizer_groups.py b/TrainingExtensions/torch/test/python/test_quantizer_groups.py
index fe11ad8574f..21d9bcec202 100644
--- a/TrainingExtensions/torch/test/python/test_quantizer_groups.py
+++ b/TrainingExtensions/torch/test/python/test_quantizer_groups.py
@@ -43,10 +43,12 @@
 from aimet_torch.batch_norm_fold import fold_all_batch_norms
 from aimet_torch.examples.test_models import SingleResidual, ConcatModel
 from aimet_torch.quantsim import QuantizationSimModel
-from aimet_torch.amp.quantizer_groups import find_quantizer_group, find_op_groups, find_supported_candidates
+from aimet_torch.amp.quantizer_groups import find_quantizer_group, find_op_groups, find_supported_candidates, \
+    QuantizerGroup
 from aimet_torch import utils
 from aimet_torch.meta.connectedgraph import ConnectedGraph
 from aimet_torch import onnx_utils
+from aimet_torch.v1.nn.modules import custom
 from torchvision.models import mobilenet_v3_large as mobilenetv3
 from models import test_models
@@ -353,13 +355,11 @@ def test_find_supported_candidates_5(self):
         assert ((16, QuantizationDataType.float), (16, QuantizationDataType.float)) in max_candidate_options
 
         for quantizer_group, candidates in quantizer_groups_with_supported_candidates.items():
-            quantizers = sorted(set(quantizer_group.get_input_quantizer_modules() +
-                                    quantizer_group.output_quantizers +
-                                    quantizer_group.parameter_quantizers))
+            supported_kernel_ops = quantizer_group.supported_kernel_ops
             onnx_types = []
-            for quantizer in quantizers:
+            for op in supported_kernel_ops:
                 onnx_types.append(
-                    onnx_utils.map_torch_types_to_onnx.get(type(module_name_to_module_dict[quantizer]._module_to_wrap)))
+                    onnx_utils.map_torch_types_to_onnx.get(type(module_name_to_module_dict[op]._module_to_wrap)))
 
             # verify to make sure the candidates returned is always part of amp_candidates and they are part of
             # either "Conv" or "Defaults"
@@ -408,3 +408,101 @@ def test_quantizer_groups_with_diff_combinations(self):
         sim = QuantizationSimModel(model, dummy_input=dummy_input)
         _, quantizer_groups = find_quantizer_group(sim)
         assert len(quantizer_groups) == 5
+
+    def test_supported_kernel_ops(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.relu1 = torch.nn.ReLU()
+                self.relu2 = torch.nn.ReLU()
+                self.add = custom.Add()
+                self.relu3 = torch.nn.ReLU()
+                self.relu4 = torch.nn.ReLU()
+
+            def forward(self, inp, inp2):
+                x1 = self.relu1(inp)
+                x2 = self.relu2(inp2)
+                x = self.add(x1, x2)
+                x1 = self.relu3(x)
+                x2 = self.relu4(x)
+                return x1, x2
+
+        model = Model()
+        dummy_input = (torch.randn(1, 3), torch.randn(1, 3))
+        sim = QuantizationSimModel(model, dummy_input=dummy_input)
+        _, quantizer_groups = find_quantizer_group(sim)
+        assert len(quantizer_groups) == 7
+
+        expected_groups = [QuantizerGroup(input_quantizers=('relu1_input_quantizer_idx_0',),
+                                          supported_kernel_ops=('relu1',)),
+                           QuantizerGroup(input_quantizers=('relu2_input_quantizer_idx_0',),
+                                          supported_kernel_ops=('relu2',)),
+                           QuantizerGroup(output_quantizers=('relu1',),
+                                          supported_kernel_ops=('add',)),
+                           QuantizerGroup(output_quantizers=('relu2',),
+                                          supported_kernel_ops=('add',)),
+                           QuantizerGroup(output_quantizers=('add',),
+                                          supported_kernel_ops=('relu3', 'relu4')),
+                           QuantizerGroup(output_quantizers=('relu3',),
+                                          supported_kernel_ops=tuple()),
+                           QuantizerGroup(output_quantizers=('relu4',),
+                                          supported_kernel_ops=tuple())]
+        for group in expected_groups:
+            assert group in quantizer_groups
+
+    def test_find_supported_kernels(self):
+        quantizer_groups_to_test = [QuantizerGroup(input_quantizers=('inp1_input_quantizer_idx_0', 'inp2_input_quantizer_idx_0'),
+                                                   supported_kernel_ops=tuple()),
+                                    QuantizerGroup(input_quantizers=('inp1_input_quantizer_idx_0', 'inp2_input_quantizer_idx_0'),
+                                                   supported_kernel_ops=('op1',)),
+                                    QuantizerGroup(input_quantizers=('inp1_input_quantizer_idx_0', 'inp2_input_quantizer_idx_0'),
+                                                   supported_kernel_ops=('op1', 'op2')),
+                                    QuantizerGroup(output_quantizers=('inp1', 'inp2'),
+                                                   supported_kernel_ops=('op1', 'op2')),
+                                    QuantizerGroup(parameter_quantizers=('inp1', 'inp2'),
+                                                   supported_kernel_ops=('op1', 'op2'))]
+        amp_candidates = [((2, QuantizationDataType.int), (2, QuantizationDataType.int)),
+                          ((3, QuantizationDataType.int), (3, QuantizationDataType.int)),
+                          ((4, QuantizationDataType.int), (4, QuantizationDataType.int)),
+                          ((5, QuantizationDataType.int), (5, QuantizationDataType.int))]
+        supported_kernels = {
+            'defaults': [((2, QuantizationDataType.int), (2, QuantizationDataType.int)),
+                         ((3, QuantizationDataType.int), (3, QuantizationDataType.int)),
+                         ((4, QuantizationDataType.int), (4, QuantizationDataType.int))],
+            'Conv': [((2, QuantizationDataType.int), (2, QuantizationDataType.int)),
+                     ((3, QuantizationDataType.int), (3, QuantizationDataType.int))],
+            'Relu': [((2, QuantizationDataType.int), (2, QuantizationDataType.int))],
+        }
+
+        class MockWrapper:
+            def __init__(self, module_to_wrap):
+                self._module_to_wrap = module_to_wrap
+
+        module_name_to_module_dict = {
+            'op1': MockWrapper(torch.nn.Conv2d(3, 8, (2, 2))),
+            'op2': MockWrapper(torch.nn.ReLU())
+        }
+
+        supported_kernel_dict, _ = find_supported_candidates(quantizer_groups_to_test,
+                                                             amp_candidates,
+                                                             supported_kernels,
+                                                             module_name_to_module_dict,
+                                                             False)
+
+        assert set(supported_kernel_dict[quantizer_groups_to_test[0]]) == {
+            ((2, QuantizationDataType.int), (2, QuantizationDataType.int)),
+            ((3, QuantizationDataType.int), (3, QuantizationDataType.int)),
+            ((4, QuantizationDataType.int), (4, QuantizationDataType.int))}
+
+        assert set(supported_kernel_dict[quantizer_groups_to_test[1]]) == {
+            ((2, QuantizationDataType.int), (2, QuantizationDataType.int)),
+            ((3, QuantizationDataType.int), (3, QuantizationDataType.int))}
+
+        assert set(supported_kernel_dict[quantizer_groups_to_test[2]]) == {
+            ((2, QuantizationDataType.int), (2, QuantizationDataType.int))}
+
+        assert set(supported_kernel_dict[quantizer_groups_to_test[3]]) == {
+            ((2, QuantizationDataType.int), (2, QuantizationDataType.int))}
+
+        assert set(supported_kernel_dict[quantizer_groups_to_test[4]]) == {
+            ((2, QuantizationDataType.int), (2, QuantizationDataType.int))}