Update PyTorch AMP supported kernels logic #3396

Merged
@@ -148,8 +148,8 @@ def rename_nodes(G: nx.DiGraph, module_name_to_module_dict: Dict, dotted_name2op
     for node in G_copy.nodes:
         if node not in ["input_ops", "output_ops"]:
             if ("input_ops", node) in G_copy.edges: # quantizer groups with an input quantizer
-                try:
-                    module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                if module_name is not None:
                     mapping = {node: module_name + "_output"}
                     G = nx.relabel_nodes(G, mapping)

@@ -173,14 +173,14 @@ def rename_nodes(G: nx.DiGraph, module_name_to_module_dict: Dict, dotted_name2op
                     G.nodes[new_node_name]["tensor_dims"] = input_shape
                     G.nodes[new_node_name]["tensor_size"] = input_size
-                except:
+                else:
                     _logger.info("did not change node name: %s", node)
             else:
-                try:
-                    module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                if module_name is not None:
                     mapping = {node: module_name + "_output"}
                     G = nx.relabel_nodes(G, mapping)
-                except:
+                else:
                     _logger.info("did not change node name: %s", node)

     G.remove_node("input_ops")
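Note on the change above: `find_wrapper_module` no longer raises `KeyError` for functional ops (see that function's diff further down), so the broad `except:` clauses become explicit `else` branches. The relabeling itself is plain networkx usage; a minimal, self-contained sketch with hypothetical node and module names:

```python
import networkx as nx

# Minimal sketch of the rename_nodes relabeling, with hypothetical names.
# A node is renamed to "<module_name>_output" only when find_wrapper_module
# resolves it to a wrapped module; functional ops are left untouched.
G = nx.DiGraph()
G.add_edge("input_ops", "conv1")  # maps to a wrapper module
G.add_edge("conv1", "relu_fn")    # functional op: no wrapper module

module_name = "conv1"             # stand-in for find_wrapper_module's result
if module_name is not None:
    G = nx.relabel_nodes(G, {module_name: module_name + "_output"})

assert "conv1_output" in G.nodes and "relu_fn" in G.nodes
```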
@@ -199,7 +199,6 @@ def get_supported_candidates_for_quantizers(quantizers: List,
             # Store candidates for quantizer
             store_candidates_for_quantizer(supported_kernels, op, amp_candidates_set, act_bw_set, act_and_param_set,
                                            act_only_set, null_intersection_ops)
-            break

     # Default candidate selected if op not found in supported kernels
     if not ops_found:
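Removing the `break` means candidate bookkeeping now runs for every op of the group found in the supported kernels, not just the first match. A rough sketch of why that matters, under the assumption that a quantizer group should only keep candidates common to all of its ops (hypothetical data):

```python
# Hypothetical candidate sets per op type.
op_candidates = {
    "Conv": {("int8", "int8"), ("int16", "int16"), ("fp16", "fp16")},
    "Relu": {("int16", "int16"), ("fp16", "fp16")},
}

group_ops = ["Conv", "Relu"]
common = set.intersection(*(op_candidates[op] for op in group_ops))
# int8 drops out because Relu does not list it; breaking after the first
# op ("Conv") would have incorrectly kept it for the whole group.
assert ("int8", "int8") not in common
```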
@@ -37,7 +37,7 @@

 """ Find quantizer groups in a model """
 import itertools
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 from collections import defaultdict
 from dataclasses import dataclass, field
 import torch
@@ -66,6 +66,7 @@ class QuantizerGroup(QuantizerGroupBase):
     input_quantizers: Tuple[str, ...] = field(default_factory=tuple)
     output_quantizers: Tuple[str, ...] = field(default_factory=tuple)
     parameter_quantizers: Tuple[str, ...] = field(default_factory=tuple)
+    supported_kernel_ops: Tuple[str, ...] = field(default_factory=tuple)

     def get_candidate(self, name_to_quantizer_dict: Dict) -> CANDIDATE_WITH_DTYPE:
         """
@@ -167,7 +168,8 @@ def get_input_quantizer_modules(self):
         return tuple(sorted(result))


-def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tuple[str, torch.nn.Module]:
+def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> \
+        Tuple[Optional[str], Optional[torch.nn.Module]]:
     """
     Finds quantization (wrapping) module corresponding to the wrapper module's dotted name
     :param op_name: Dotted name of op as represented in connected graph
@@ -179,7 +181,7 @@ def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tu
         if module_name in module_name_to_quantizer_dict:
             return module_name, module_name_to_quantizer_dict[module_name]
     # Else it is a functional op
-    raise KeyError
+    return None, None
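This return-value change is the crux of the PR's control-flow cleanup: functional ops now yield `(None, None)` instead of raising, so every caller swaps a `try/except KeyError` for a plain `None` check. Sketch of the before/after calling patterns (illustrative only):

```python
# Before: every caller had to trap KeyError for functional ops.
try:
    module_name, module = find_wrapper_module(op_name, name_to_module_dict)
except KeyError:
    module_name, module = None, None

# After: a sentinel return lets callers branch with a simple conditional,
# avoiding the bare `except:` clauses removed in rename_nodes above.
module_name, module = find_wrapper_module(op_name, name_to_module_dict)
if module_name is not None:
    pass  # rename the node / collect the module's quantizers
```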


def get_module_name_to_module_dict(sim: QuantizationSimModel) -> Dict:
Expand Down Expand Up @@ -320,11 +322,8 @@ def get_input_and_param_quantizers(
"""
input_quantizers = []
parameter_quantizers = []
try:
module_name, module = find_wrapper_module(child, module_name_to_module_dict)
except KeyError:
pass
else:
module_name, module = find_wrapper_module(child, module_name_to_module_dict)
if module_name is not None:
for idx, input_quantizer in enumerate(module.input_quantizers):
if input_quantizer.enabled:
input_quantizers.append(module_name + '_input_quantizer_idx_' + str(idx))
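For reference, the enumeration above yields names like `conv1_input_quantizer_idx_0`. A tiny self-contained sketch with hypothetical stand-ins for the wrapper and its quantizers:

```python
# Hypothetical wrapper with two input quantizers, the second disabled.
class FakeQuantizer:
    def __init__(self, enabled):
        self.enabled = enabled

class FakeWrapper:
    input_quantizers = [FakeQuantizer(True), FakeQuantizer(False)]

module_name, module = "conv1", FakeWrapper()
input_quantizers = []
if module_name is not None:
    for idx, input_quantizer in enumerate(module.input_quantizers):
        if input_quantizer.enabled:
            input_quantizers.append(module_name + '_input_quantizer_idx_' + str(idx))

assert input_quantizers == ["conv1_input_quantizer_idx_0"]
```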
@@ -358,9 +357,14 @@ def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[Quantize
             # Add one quantizer group for each input and it's weight param
             input_quantizers, parameter_quantizers = get_input_and_param_quantizers(child, module_name_to_module_dict)
             if input_quantizers or parameter_quantizers:
+                child_module_name, _ = find_wrapper_module(child, module_name_to_module_dict)
+                supported_kernel_ops = []
+                if child_module_name is not None:
+                    supported_kernel_ops.append(child_module_name)
                 quantizer_group = QuantizerGroup(
                     input_quantizers=input_quantizers,
-                    parameter_quantizers=parameter_quantizers
+                    parameter_quantizers=parameter_quantizers,
+                    supported_kernel_ops=tuple(supported_kernel_ops)
                 )
                 quantizer_groups.append(quantizer_group)
                 logger.debug('\n Quantizer Group Added: %s', quantizer_group)
@@ -375,42 +379,44 @@ def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[Quantize
             if not isinstance(parents, tuple):
                 parents = [parents]
             for parent in parents:
-                try:
-                    module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
-                except KeyError:
-                    continue
+                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
+                if module is not None:
                     for output_quantizer in module.output_quantizers:
                         if output_quantizer.enabled:
                             output_quantizers += (module_name,)

+            supported_kernel_ops = []
             for child in children:
                 input_q, param_q = get_input_and_param_quantizers(child, module_name_to_module_dict)
                 input_quantizers += input_q
                 parameter_quantizers += param_q
+                child_module_name, _ = find_wrapper_module(child, module_name_to_module_dict)
+                if child_module_name is not None:
+                    supported_kernel_ops.append(child_module_name)

             # Don't add quantizer group if it is empty
             if input_quantizers or output_quantizers or parameter_quantizers:
                 quantizer_group = QuantizerGroup(
                     input_quantizers=input_quantizers,
                     output_quantizers=output_quantizers,
-                    parameter_quantizers=parameter_quantizers
+                    parameter_quantizers=parameter_quantizers,
+                    supported_kernel_ops=tuple(supported_kernel_ops)
                 )
                 quantizer_groups.append(quantizer_group)
                 logger.debug('\n Quantizer Group added: %s', quantizer_group)

     if 'output_ops' in parent_child_op_groups:
         for parent in parent_child_op_groups['output_ops']:
             # Add one quantizer group for each input and it's weight param
-            try:
-                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
-            except KeyError:
-                continue
+            module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
+            if module is not None:
                 for output_quantizer in module.output_quantizers:
                     if output_quantizer.enabled:
+                        # Using empty supported kernel ops so that model output quantizers are able to consider all
+                        # default candidates
                         quantizer_group = QuantizerGroup(
                             output_quantizers=(module_name,),
+                            supported_kernel_ops=tuple()
                         )
                         quantizer_groups.append(quantizer_group)
                         logger.debug('\n Quantizer Group added: %s', quantizer_group)
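A condensed, self-contained sketch of the bookkeeping added above: wrapped child ops are accumulated into `supported_kernel_ops`, functional children contribute nothing, and groups built for model outputs get an empty tuple so they keep all default candidates (hypothetical names throughout):

```python
# Hypothetical module dict: only wrapped (non-functional) ops appear in it.
module_name_to_module_dict = {"block1.conv": object()}

def find_wrapper_module_sketch(op_name, name_to_module):
    # Mirrors the new sentinel behavior of find_wrapper_module.
    if op_name in name_to_module:
        return op_name, name_to_module[op_name]
    return None, None

supported_kernel_ops = []
for child in ["block1.conv", "functional_add"]:
    child_module_name, _ = find_wrapper_module_sketch(child, module_name_to_module_dict)
    if child_module_name is not None:   # functional ops return (None, None)
        supported_kernel_ops.append(child_module_name)

assert tuple(supported_kernel_ops) == ("block1.conv",)
# Model-output groups instead pass supported_kernel_ops=tuple() so they
# remain eligible for all default candidates.
```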
@@ -437,42 +443,42 @@ def find_supported_candidates(quantizer_groups: List[QuantizerGroup],
     quantizers_with_supported_candidates = defaultdict(list)

     # pylint: disable=too-many-nested-blocks
+    # pylint: disable=protected-access
     for quantizer_group in quantizer_groups:
         quantizers = sorted(set(itertools.chain(quantizer_group.get_input_quantizer_modules(),
                                                 quantizer_group.output_quantizers,
                                                 quantizer_group.parameter_quantizers)))

         # quantizers are now unique ops present in the given quantizer_group
         onnx_ops = defaultdict(list)
-        for quantizer in quantizers:
-            if quantizer not in module_name_to_module_dict:
-                raise RuntimeError('module_name_to_module_dict does not contain an entry for the quantizer:',
-                                   quantizer)
-
-            # pylint: disable=protected-access
-            module = module_name_to_module_dict[quantizer]._module_to_wrap
-
+        supported_kernel_types = set()
+        for supported_kernel_op in quantizer_group.supported_kernel_ops:
+            module = module_name_to_module_dict[supported_kernel_op]._module_to_wrap
             try:
                 backend_type = aimet_op_to_backend_op_name_map[module.__class__]
             except KeyError:
                 backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)

             if backend_type in supported_kernels:
-                onnx_ops[quantizer] = [backend_type]
+                supported_kernel_types.add(backend_type)
             else:
                 onnx_types = onnx_utils.map_torch_types_to_onnx.get(
-                    type(module_name_to_module_dict[quantizer]._module_to_wrap), [])
+                    type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap), [])

                 if not onnx_types:
                     logger.warning("No mapping found for %s in the torch to onnx op type mapping dictionary.",
-                                   str(type(module_name_to_module_dict[quantizer]._module_to_wrap)))
+                                   str(type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap)))

+                supported_kernel_types.update(onnx_types)

-                onnx_ops[quantizer] = onnx_types
                 for onnx_type in onnx_types:
                     if onnx_type not in supported_kernels.keys():
                         if module in supported_kernels:
                             supported_kernels[onnx_type] = supported_kernels[module]

+        for quantizer in quantizers:
+            onnx_ops[quantizer] = list(supported_kernel_types)

         supported_kernels_for_quantizers = get_supported_candidates_for_quantizers(quantizers,
                                                                                    onnx_ops,
                                                                                    supported_kernels,
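The behavioral shift here: op types used for kernel lookup now come from the group's `supported_kernel_ops`, and the pooled result is assigned to every quantizer in the group. A rough, self-contained sketch with hypothetical stand-ins for `aimet_op_to_backend_op_name_map` and `onnx_utils.map_torch_types_to_onnx`:

```python
from collections import defaultdict

backend_name_of = {"Conv2d": "Conv"}       # hypothetical backend-name map
onnx_types_of = {"ReLU": ["Relu"]}         # hypothetical torch->onnx map
supported_kernels = {"Conv": ["kernel_a"]}

def pooled_onnx_ops(quantizers, supported_kernel_ops, module_cls_of):
    """Give every quantizer in a group the union of op types contributed
    by the group's supported-kernel ops."""
    kernel_types = set()
    for op in supported_kernel_ops:
        cls_name = module_cls_of[op]
        backend_type = backend_name_of.get(cls_name)
        if backend_type in supported_kernels:
            kernel_types.add(backend_type)
        else:
            kernel_types.update(onnx_types_of.get(cls_name, []))
    onnx_ops = defaultdict(list)
    for quantizer in quantizers:
        onnx_ops[quantizer] = list(kernel_types)
    return onnx_ops

ops = pooled_onnx_ops(["conv1", "act1"], ["conv1", "act1"],
                      {"conv1": "Conv2d", "act1": "ReLU"})
assert sorted(ops["conv1"]) == sorted(ops["act1"]) == ["Conv", "Relu"]
```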
10 changes: 4 additions & 6 deletions TrainingExtensions/torch/test/python/test_mixed_precision.py
@@ -1111,19 +1111,17 @@ def test_supported_candidates_2(
         # default_supported_kernels and conv_supported_kernels are the configurations added in the json file above.
         default_supported_kernels = [((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
                                      ((16, QuantizationDataType.float), (16, QuantizationDataType.float)),
-                                     ((8, QuantizationDataType.float), (16, QuantizationDataType.float))]
+                                     ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]

         conv_supported_kernels = [((16, QuantizationDataType.float), (16, QuantizationDataType.float)),
                                   ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]

         for quantizer_group, quantizer_candidates in algo._supported_candidates_per_quantizer_group.items():
-            quantizers = sorted(itertools.chain(quantizer_group.get_input_quantizer_modules(),
-                                                quantizer_group.output_quantizers,
-                                                quantizer_group.parameter_quantizers))
+            supported_kernel_ops = quantizer_group.supported_kernel_ops
             onnx_types = []
-            for q in quantizers:
+            for op in supported_kernel_ops:
                 onnx_types.append(
-                    onnx_utils.map_torch_types_to_onnx.get(type(algo._module_name_dict[q]._module_to_wrap)))
+                    onnx_utils.map_torch_types_to_onnx.get(type(algo._module_name_dict[op]._module_to_wrap)))

         # verify to make sure the candidates returned is always part of amp_candidates and they are part of
         # "Defaults" or "Conv" appropriately