Skip to content

Commit

Permalink
Migrate 'implements' method to base class (#3342)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Sep 14, 2024
1 parent 644df6f commit 91f46e8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 30 deletions.
16 changes: 16 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class BaseQuantizationMixin(abc.ABC):
output_quantizers: nn.ModuleList
param_quantizers: nn.ModuleDict

cls_to_qcls: dict
qcls_to_cls: dict

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__quant_init__()
Expand Down Expand Up @@ -199,6 +202,19 @@ def wrap(cls, module_cls: Type[nn.Module]):
Wrap a regular module class into a quantized module class
"""

@classmethod
def implements(cls, module_cls):
"""
Decorator for registering quantized implementation of the given base class.
"""

def wrapper(quantized_cls):
cls.cls_to_qcls[module_cls] = quantized_cls
cls.qcls_to_cls[quantized_cls] = module_cls
return quantized_cls

return wrapper

@classmethod
def from_module(cls, module: nn.Module):
r"""Create an instance of quantized module from a regular module instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,6 @@ def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]:
quantized_cls = type(quantized_cls_name, base_classes, {'__module__': __name__})
return cls.implements(module_cls)(quantized_cls)

@classmethod
def implements(cls, module_cls):
"""Decorator for registering a fake-quantized implementation of the given base class.
This decorator registers the defined class as the fake-quantized version of module_cls such that calling
:meth:`from_module` on an instance of module_cls will output an instance of the decorated class.
Args:
module_cls: The base :class:`torch.nn.Module` class
"""
def wrapper(quantized_cls):
cls.cls_to_qcls[module_cls] = quantized_cls
cls.qcls_to_cls[quantized_cls] = module_cls
return quantized_cls
return wrapper


class _FakeQuantizedUnaryOpMixin(FakeQuantizationMixin): # pylint: disable=abstract-method
def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,19 +332,6 @@ def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]:
quantized_cls = type(quantized_cls_name, base_classes, {'__module__': __name__})
return cls.implements(module_cls)(quantized_cls)

@classmethod
def implements(cls, module_cls):
"""
Decorator for registering quantized implementation of the given base class.
"""

def wrapper(quantized_cls):
cls.cls_to_qcls[module_cls] = quantized_cls
cls.qcls_to_cls[quantized_cls] = module_cls
return quantized_cls

return wrapper


# pylint: disable=too-many-ancestors

Expand Down

0 comments on commit 91f46e8

Please sign in to comment.