From 0a0ead8201f6ea097818f773335ef7a3631826e0 Mon Sep 17 00:00:00 2001 From: Chetan Gulecha Date: Thu, 10 Oct 2024 21:35:26 +0530 Subject: [PATCH] Fixed issue in op name retrival (#3383) Signed-off-by: Chetan Gulecha --- .../torch/src/python/aimet_torch/amp/quantizer_groups.py | 2 +- .../src/python/aimet_torch/quantsim_config/quantsim_config.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py b/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py index 63165a4905..247c3ca857 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/amp/quantizer_groups.py @@ -450,7 +450,7 @@ def find_supported_candidates(quantizer_groups: List[QuantizerGroup], quantizer) # pylint: disable=protected-access - module = module_name_to_module_dict[quantizer]._module_to_wrap.__class__ + module = module_name_to_module_dict[quantizer]._module_to_wrap try: backend_type = aimet_op_to_backend_op_name_map[module.__class__] diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/quantsim_config.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/quantsim_config.py index 2aa083f422..8125789fb4 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/quantsim_config.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/quantsim_config.py @@ -942,11 +942,11 @@ def generate(self, module: torch.nn.Module, op_type: str) -> Tuple[dict, bool]: :return: supported_kernels and per_channel_quantization fields """ supported_kernels = [] - if module.__class__ in aimet_op_to_backend_op_name_map: + if (module.__class__ in aimet_op_to_backend_op_name_map) or (module.__class__.__name__ in aimet_op_to_backend_op_name_map): try: backend_type = aimet_op_to_backend_op_name_map[module.__class__] except KeyError: - backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__) + backend_type = aimet_op_to_backend_op_name_map[module.__class__.__name__] supported_kernels = self.op_type_supported_kernels.get(backend_type) if not supported_kernels: