From 20526ee727ef02c1befb5a36286af8a87e22472d Mon Sep 17 00:00:00 2001 From: Kyunggeun Lee Date: Tue, 3 Sep 2024 14:02:40 -0700 Subject: [PATCH] Deprecate FakeQuantizationMixin (#3311) Signed-off-by: Kyunggeun Lee --- .../aimet_torch/quantsim_config/builder.py | 6 +- .../src/python/aimet_torch/v2/nn/base.py | 64 ------ .../aimet_torch/v2/nn/fake_quant/__init__.py | 188 ++++++++++++++++++ .../_legacy_impl.py} | 71 ++++++- .../python/aimet_torch/v2/nn/true_quant.py | 18 +- .../aimet_torch/v2/quantsim/quantsim.py | 3 +- .../test/python/v2/ab_test/test_quantizer_.py | 2 +- .../python/v2/ab_test/test_quantsim_export.py | 14 +- .../test/python/v2/models_/models_to_test.py | 6 +- .../test/python/v2/models_/test_models.py | 2 +- .../test/python/v2/nn/deprecated/__init__.py | 36 ++++ .../v2/nn/{ => deprecated}/test_activation.py | 2 +- .../v2/nn/{ => deprecated}/test_linear.py | 2 +- .../torch/test/python/v2/nn/test_custom_op.py | 34 ++-- .../test/python/v2/nn/test_true_quant.py | 8 +- .../torch/test/python/v2/test_seq_mse_.py | 6 +- 16 files changed, 345 insertions(+), 117 deletions(-) create mode 100644 TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/__init__.py rename TrainingExtensions/torch/src/python/aimet_torch/v2/nn/{fake_quant.py => fake_quant/_legacy_impl.py} (92%) create mode 100644 TrainingExtensions/torch/test/python/v2/nn/deprecated/__init__.py rename TrainingExtensions/torch/test/python/v2/nn/{ => deprecated}/test_activation.py (98%) rename TrainingExtensions/torch/test/python/v2/nn/{ => deprecated}/test_linear.py (99%) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/builder.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/builder.py index 6f79067a4e..86e3151f21 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/builder.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim_config/builder.py @@ -197,11 +197,13 @@ def realize_v2_wrapper(self): :return: v2 quant wrapper with specified properties """ - from aimet_torch.v2.nn import FakeQuantizationMixin, QuantizationMixin + from aimet_torch.v2.nn import QuantizationMixin + from aimet_torch.v2.nn.fake_quant import _legacy_impl + if type(self._module_to_wrap) in QuantizationMixin.cls_to_qcls: # pylint: disable=unidiomatic-typecheck quantized_module = QuantizationMixin.from_module(self._module_to_wrap) else: - quantized_module = FakeQuantizationMixin.from_module(self._module_to_wrap) + quantized_module = _legacy_impl.FakeQuantizationMixin.from_module(self._module_to_wrap) def set_recursive(module_list, i, quantizer): """ diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py index 35f1398cdc..7b8c3f43e0 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py @@ -42,7 +42,6 @@ from typing import Type, List, Dict, Union, Iterable, Mapping, Optional import torch.nn as nn -from torch import Tensor from aimet_torch.utils import is_vector_encoding from aimet_torch.v2.quantization.affine.encoding import VectorEncoding, AffineEncoding @@ -532,69 +531,6 @@ def _remove_all_quantizers(self): return _ContextManager(action=lambda: None, cleanup=lambda: (ctx_1._cleanup(), ctx_2._cleanup())) -class _BaseQuantizedUnaryOpMixin(BaseQuantizationMixin): - def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring - x, *others = args - - if isinstance(x, 
Tensor) and x.is_floating_point() and self.input_quantizers[0]: - x = self.input_quantizers[0](x) - - with self._patch_quantized_parameters(): - output = super().forward(x, *others, **kwargs) - - if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]: - output = self.output_quantizers[0](output) - - return output - -class _BaseQuantizedBinaryOpMixin(BaseQuantizationMixin): - def __quant_init__(self): - super().__quant_init__() - self.input_quantizers = nn.ModuleList([None, None]) - - def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring - x, y, *others = args - - if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]: - x = self.input_quantizers[0](x) - - if isinstance(y, Tensor) and y.is_floating_point() and self.input_quantizers[1]: - y = self.input_quantizers[1](y) - - with self._patch_quantized_parameters(): - output = super().forward(x, y, *others, **kwargs) - - if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]: - output = self.output_quantizers[0](output) - - return output - - -class _BaseQuantizedTernaryOpMixin(BaseQuantizationMixin): - def __quant_init__(self): - super().__quant_init__() - self.input_quantizers = nn.ModuleList([None, None, None]) - - def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring - x, y, z, *others = args - - if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]: - x = self.input_quantizers[0](x) - - if isinstance(y, Tensor) and y.is_floating_point() and self.input_quantizers[1]: - y = self.input_quantizers[1](y) - - if isinstance(z, Tensor) and z.is_floating_point() and self.input_quantizers[2]: - z = self.input_quantizers[2](z) - - with self._patch_quantized_parameters(): - output = super().forward(x, y, z, *others, **kwargs) - - if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]: - output = self.output_quantizers[0](output) - - return output - def _remove_quantizers(quantizers, keys): orig_quantizers = {key: quantizers[key] for key in keys} diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/__init__.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/__init__.py new file mode 100644 index 0000000000..ec79b5e9f6 --- /dev/null +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/__init__.py @@ -0,0 +1,188 @@ +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# @@-COPYRIGHT-END-@@ +# ============================================================================= +""" +Placeholder of the deprecated aimet_torch/v2/nn/fake_quant.py for backward compatibility. + +FakeQuantized- modules are now completely superseded by Quantized- modules, +and any legacy user code that tries to import FakeQuantized- modules will import Quantized- modules instead. +This package serves as a namespace that maps the legacy FakeQuantized- modules to the Quantized- equivalents +for backward compatibility. +""" + +import torch +from packaging import version +from .. import true_quant as _nn + +FakeQuantizationMixin = _nn.QuantizationMixin +FakeQuantizedAdaptiveMaxPool1d = _nn.QuantizedAdaptiveMaxPool1d +FakeQuantizedAdaptiveMaxPool2d = _nn.QuantizedAdaptiveMaxPool2d +FakeQuantizedAdaptiveMaxPool3d = _nn.QuantizedAdaptiveMaxPool3d +FakeQuantizedAlphaDropout = _nn.QuantizedAlphaDropout +FakeQuantizedAvgPool1d = _nn.QuantizedAvgPool1d +FakeQuantizedAvgPool2d = _nn.QuantizedAvgPool2d +FakeQuantizedAvgPool3d = _nn.QuantizedAvgPool3d +FakeQuantizedBCELoss = _nn.QuantizedBCELoss +FakeQuantizedBCEWithLogitsLoss = _nn.QuantizedBCEWithLogitsLoss +FakeQuantizedBatchNorm1d = _nn.QuantizedBatchNorm1d +FakeQuantizedBatchNorm2d = _nn.QuantizedBatchNorm2d +FakeQuantizedBatchNorm3d = _nn.QuantizedBatchNorm3d +FakeQuantizedBilinear = _nn.QuantizedBilinear +FakeQuantizedCELU = _nn.QuantizedCELU +FakeQuantizedCTCLoss = _nn.QuantizedCTCLoss +FakeQuantizedChannelShuffle = _nn.QuantizedChannelShuffle + +if version.parse(torch.__version__) >= version.parse("2.1.0"): + FakeQuantizedCircularPad1d = _nn.QuantizedCircularPad1d + FakeQuantizedCircularPad2d = _nn.QuantizedCircularPad2d + FakeQuantizedCircularPad3d = _nn.QuantizedCircularPad3d + +FakeQuantizedConstantPad1d = _nn.QuantizedConstantPad1d +FakeQuantizedConstantPad2d = _nn.QuantizedConstantPad2d +FakeQuantizedConstantPad3d = _nn.QuantizedConstantPad3d +FakeQuantizedConv1d = _nn.QuantizedConv1d +FakeQuantizedConv2d = _nn.QuantizedConv2d +FakeQuantizedConv3d = _nn.QuantizedConv3d +FakeQuantizedConvTranspose1d = _nn.QuantizedConvTranspose1d +FakeQuantizedConvTranspose2d = _nn.QuantizedConvTranspose2d +FakeQuantizedConvTranspose3d = _nn.QuantizedConvTranspose3d +FakeQuantizedCosineEmbeddingLoss = _nn.QuantizedCosineEmbeddingLoss +FakeQuantizedCosineSimilarity = _nn.QuantizedCosineSimilarity +FakeQuantizedCrossEntropyLoss = _nn.QuantizedCrossEntropyLoss +FakeQuantizedDropout = _nn.QuantizedDropout + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + FakeQuantizedDropout1d = _nn.QuantizedDropout1d + +FakeQuantizedDropout2d = _nn.QuantizedDropout2d +FakeQuantizedDropout3d = _nn.QuantizedDropout3d 
+FakeQuantizedELU = _nn.QuantizedELU +FakeQuantizedEmbedding = _nn.QuantizedEmbedding +FakeQuantizedEmbeddingBag = _nn.QuantizedEmbeddingBag +FakeQuantizedFeatureAlphaDropout = _nn.QuantizedFeatureAlphaDropout +FakeQuantizedFlatten = _nn.QuantizedFlatten +FakeQuantizedFold = _nn.QuantizedFold +FakeQuantizedFractionalMaxPool2d = _nn.QuantizedFractionalMaxPool2d +FakeQuantizedFractionalMaxPool3d = _nn.QuantizedFractionalMaxPool3d +FakeQuantizedGELU = _nn.QuantizedGELU +FakeQuantizedGLU = _nn.QuantizedGLU +FakeQuantizedGRU = _nn.QuantizedGRU +FakeQuantizedGRUCell = _nn.QuantizedGRUCell +FakeQuantizedGaussianNLLLoss = _nn.QuantizedGaussianNLLLoss +FakeQuantizedGroupNorm = _nn.QuantizedGroupNorm +FakeQuantizedHardshrink = _nn.QuantizedHardshrink +FakeQuantizedHardsigmoid = _nn.QuantizedHardsigmoid +FakeQuantizedHardswish = _nn.QuantizedHardswish +FakeQuantizedHardtanh = _nn.QuantizedHardtanh +FakeQuantizedHingeEmbeddingLoss = _nn.QuantizedHingeEmbeddingLoss +FakeQuantizedHuberLoss = _nn.QuantizedHuberLoss +FakeQuantizedInstanceNorm1d = _nn.QuantizedInstanceNorm1d +FakeQuantizedInstanceNorm2d = _nn.QuantizedInstanceNorm2d +FakeQuantizedInstanceNorm3d = _nn.QuantizedInstanceNorm3d +FakeQuantizedKLDivLoss = _nn.QuantizedKLDivLoss +FakeQuantizedL1Loss = _nn.QuantizedL1Loss +FakeQuantizedLPPool1d = _nn.QuantizedLPPool1d +FakeQuantizedLPPool2d = _nn.QuantizedLPPool2d +FakeQuantizedLSTM = _nn.QuantizedLSTM +FakeQuantizedLSTMCell = _nn.QuantizedLSTMCell +FakeQuantizedLayerNorm = _nn.QuantizedLayerNorm +FakeQuantizedLeakyReLU = _nn.QuantizedLeakyReLU +FakeQuantizedLinear = _nn.QuantizedLinear +FakeQuantizedLocalResponseNorm = _nn.QuantizedLocalResponseNorm +FakeQuantizedLogSigmoid = _nn.QuantizedLogSigmoid +FakeQuantizedLogSoftmax = _nn.QuantizedLogSoftmax +FakeQuantizedMSELoss = _nn.QuantizedMSELoss +FakeQuantizedMarginRankingLoss = _nn.QuantizedMarginRankingLoss +FakeQuantizedMaxPool1d = _nn.QuantizedMaxPool1d +FakeQuantizedMaxPool2d = _nn.QuantizedMaxPool2d +FakeQuantizedMaxPool3d = _nn.QuantizedMaxPool3d +FakeQuantizedMaxUnpool1d = _nn.QuantizedMaxUnpool1d +FakeQuantizedMaxUnpool2d = _nn.QuantizedMaxUnpool2d +FakeQuantizedMaxUnpool3d = _nn.QuantizedMaxUnpool3d +FakeQuantizedMish = _nn.QuantizedMish +FakeQuantizedMultiLabelMarginLoss = _nn.QuantizedMultiLabelMarginLoss +FakeQuantizedMultiLabelSoftMarginLoss = _nn.QuantizedMultiLabelSoftMarginLoss +FakeQuantizedMultiMarginLoss = _nn.QuantizedMultiMarginLoss +FakeQuantizedNLLLoss = _nn.QuantizedNLLLoss +FakeQuantizedNLLLoss2d = _nn.QuantizedNLLLoss2d +FakeQuantizedPReLU = _nn.QuantizedPReLU +FakeQuantizedPairwiseDistance = _nn.QuantizedPairwiseDistance +FakeQuantizedPixelShuffle = _nn.QuantizedPixelShuffle +FakeQuantizedPixelUnshuffle = _nn.QuantizedPixelUnshuffle +FakeQuantizedPoissonNLLLoss = _nn.QuantizedPoissonNLLLoss +FakeQuantizedRNN = _nn.QuantizedRNN +FakeQuantizedRNNCell = _nn.QuantizedRNNCell +FakeQuantizedRReLU = _nn.QuantizedRReLU +FakeQuantizedReLU = _nn.QuantizedReLU +FakeQuantizedReLU6 = _nn.QuantizedReLU6 +FakeQuantizedReflectionPad1d = _nn.QuantizedReflectionPad1d +FakeQuantizedReflectionPad2d = _nn.QuantizedReflectionPad2d + +if version.parse(torch.__version__) >= version.parse("1.10.0"): + FakeQuantizedReflectionPad3d = _nn.QuantizedReflectionPad3d + +FakeQuantizedReplicationPad1d = _nn.QuantizedReplicationPad1d +FakeQuantizedReplicationPad2d = _nn.QuantizedReplicationPad2d +FakeQuantizedReplicationPad3d = _nn.QuantizedReplicationPad3d +FakeQuantizedSELU = _nn.QuantizedSELU +FakeQuantizedSiLU = _nn.QuantizedSiLU +FakeQuantizedSigmoid 
= _nn.QuantizedSigmoid +FakeQuantizedSmoothL1Loss = _nn.QuantizedSmoothL1Loss +FakeQuantizedSoftMarginLoss = _nn.QuantizedSoftMarginLoss +FakeQuantizedSoftmax = _nn.QuantizedSoftmax +FakeQuantizedSoftmax2d = _nn.QuantizedSoftmax2d +FakeQuantizedSoftmin = _nn.QuantizedSoftmin +FakeQuantizedSoftplus = _nn.QuantizedSoftplus +FakeQuantizedSoftshrink = _nn.QuantizedSoftshrink +FakeQuantizedSoftsign = _nn.QuantizedSoftsign +FakeQuantizedTanh = _nn.QuantizedTanh +FakeQuantizedTanhshrink = _nn.QuantizedTanhshrink +FakeQuantizedThreshold = _nn.QuantizedThreshold +FakeQuantizedTripletMarginLoss = _nn.QuantizedTripletMarginLoss +FakeQuantizedTripletMarginWithDistanceLoss = _nn.QuantizedTripletMarginWithDistanceLoss +FakeQuantizedUnflatten = _nn.QuantizedUnflatten +FakeQuantizedUnfold = _nn.QuantizedUnfold +FakeQuantizedUpsample = _nn.QuantizedUpsample +FakeQuantizedUpsamplingBilinear2d = _nn.QuantizedUpsamplingBilinear2d +FakeQuantizedUpsamplingNearest2d = _nn.QuantizedUpsamplingNearest2d + +if version.parse(torch.__version__) >= version.parse("2.1.0"): + FakeQuantizedZeroPad1d = _nn.QuantizedZeroPad1d + +FakeQuantizedZeroPad2d = _nn.QuantizedZeroPad2d + +if version.parse(torch.__version__) >= version.parse("2.1.0"): + FakeQuantizedZeroPad3d = _nn.QuantizedZeroPad3d diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/_legacy_impl.py similarity index 92% rename from TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant.py rename to TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/_legacy_impl.py index 9258f914da..f815640a83 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/_legacy_impl.py @@ -35,7 +35,7 @@ # @@-COPYRIGHT-END-@@ # ============================================================================= # pylint: disable=too-many-lines, wrong-import-order -"""Fake-quantized modules""" +"""Fake-quantized modules (deprecated)""" from packaging import version from collections import OrderedDict @@ -50,8 +50,8 @@ from torch.nn.utils.rnn import PackedSequence from torch.utils._pytree import tree_map -from .base import BaseQuantizationMixin, _BaseQuantizedUnaryOpMixin, _BaseQuantizedBinaryOpMixin, _BaseQuantizedTernaryOpMixin # pylint: disable=import-error -from .modules import custom # pylint: disable=import-error +from ..base import BaseQuantizationMixin # pylint: disable=import-error +from ..modules import custom # pylint: disable=import-error class FakeQuantMeta(abc.ABCMeta): @@ -178,14 +178,67 @@ def wrapper(quantized_cls): return wrapper -class _FakeQuantizedUnaryOpMixin(_BaseQuantizedUnaryOpMixin, FakeQuantizationMixin): # pylint: disable=abstract-method - pass +class _FakeQuantizedUnaryOpMixin(FakeQuantizationMixin): # pylint: disable=abstract-method + def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring + x, *others = args + + if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]: + x = self.input_quantizers[0](x) + + with self._patch_quantized_parameters(): + output = super().forward(x, *others, **kwargs) + + if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]: + output = self.output_quantizers[0](output) + + return output + +class _FakeQuantizedBinaryOpMixin(FakeQuantizationMixin): # pylint: disable=abstract-method + def __quant_init__(self): + super().__quant_init__() + 
self.input_quantizers = nn.ModuleList([None, None]) + + def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring + x, y, *others = args + + if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]: + x = self.input_quantizers[0](x) + + if isinstance(y, Tensor) and y.is_floating_point() and self.input_quantizers[1]: + y = self.input_quantizers[1](y) + + with self._patch_quantized_parameters(): + output = super().forward(x, y, *others, **kwargs) + + if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]: + output = self.output_quantizers[0](output) + + return output + +class _FakeQuantizedTernaryOpMixin(FakeQuantizationMixin): # pylint: disable=abstract-method + def __quant_init__(self): + super().__quant_init__() + self.input_quantizers = nn.ModuleList([None, None, None]) -class _FakeQuantizedBinaryOpMixin(_BaseQuantizedBinaryOpMixin, FakeQuantizationMixin): # pylint: disable=abstract-method - pass + def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring + x, y, z, *others = args -class _FakeQuantizedTernaryOpMixin(_BaseQuantizedTernaryOpMixin, FakeQuantizationMixin): # pylint: disable=abstract-method - pass + if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]: + x = self.input_quantizers[0](x) + + if isinstance(y, Tensor) and y.is_floating_point() and self.input_quantizers[1]: + y = self.input_quantizers[1](y) + + if isinstance(z, Tensor) and z.is_floating_point() and self.input_quantizers[2]: + z = self.input_quantizers[2](z) + + with self._patch_quantized_parameters(): + output = super().forward(x, y, z, *others, **kwargs) + + if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]: + output = self.output_quantizers[0](output) + + return output diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py index e87d2bae1f..9f5ffa199a 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py @@ -44,6 +44,7 @@ from collections import OrderedDict from typing import Type, Any, Optional, Callable, Dict from weakref import WeakKeyDictionary +import warnings import torch import torch.nn as nn @@ -116,7 +117,20 @@ def _exit_compute_encodings(qmodule): _QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule] -= 1 -class QuantizationMixin(BaseQuantizationMixin): # pylint: disable=abstract-method +class QuantizationMixinMeta(ABCMeta): + """Sets :meth:`forward` to :meth:`quantized_forward` if only :meth:`quantized_forward` is defined + """ + + def __new__(mcs, name, bases, namespace, **kwargs): + if "quantized_forward" in namespace and "forward" not in namespace: + warnings.warn("Support for defining `quantized_forward` in place of `forward` method will be deprecated, " + "please use `forward` instead.", + DeprecationWarning, stacklevel=2) + namespace["forward"] = namespace["quantized_forward"] + return super().__new__(mcs, name, bases, namespace, **kwargs) + + +class QuantizationMixin(BaseQuantizationMixin, metaclass=QuantizationMixinMeta): # pylint: disable=abstract-method """Mixin that adds quantization functionality on top of regular pytorch modules. 
:class:`QuantizationMixin` provides all the same behavior as :class:`FakeQuantizationMixin`, and by default, a @@ -389,7 +403,7 @@ def _dispatch(torch_func: Callable, custom_impl: Callable): _dispatcher.__exit__(None, None, None) -class _DispatchMeta(ABCMeta): +class _DispatchMeta(QuantizationMixinMeta): def __new__(mcs, name, bases, namespace, **kwargs): """ Sanity check for class definitions of dispatch-based quantized modules diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py index 914b039e97..1685de0988 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py @@ -45,7 +45,6 @@ from aimet_torch.quantsim import QuantizationSimModel as V1QuantizationSimModel, logger import aimet_torch.quantsim as quantsim_v1 from aimet_torch.v2 import nn as aimet_nn -from aimet_torch.v2.nn import FakeQuantizationMixin from aimet_torch.v2.nn import BaseQuantizationMixin from aimet_torch.quantsim_config.builder import LazyQuantizeWrapper from aimet_torch.v2.quantization.base import QuantizerBase @@ -111,7 +110,7 @@ def __init__(self, model, *args, **kwargs): # pylint: disable=arguments-differ module.to(device=device) @staticmethod - def _realize_quant_wrapper(module: LazyQuantizeWrapper) -> FakeQuantizationMixin: + def _realize_quant_wrapper(module: LazyQuantizeWrapper) -> BaseQuantizationMixin: """ Make wrapper builder into v2 quant wrapper diff --git a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantizer_.py b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantizer_.py index 50613352d1..cd84521b8c 100644 --- a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantizer_.py +++ b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantizer_.py @@ -65,7 +65,7 @@ from aimet_torch.qc_quantize_op import QcQuantizeWrapper, QcQuantizeStandalone, StaticGridQuantWrapper from aimet_torch.quantsim import check_accumulator_overflow, compute_encodings_for_sims import aimet_torch.v2.nn as aimet_nn -from aimet_torch.v2.nn.fake_quant import _FakeQuantizedUnaryOpMixin +from aimet_torch.v2.nn.fake_quant._legacy_impl import _FakeQuantizedUnaryOpMixin from aimet_torch.v2.quantization.affine import QuantizeDequantize from aimet_torch.v2.quantization.float import FloatQuantizeDequantize from aimet_torch.v2.quantsim import QuantizationSimModel diff --git a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py index b5d89c7c1d..e1456f5e9f 100644 --- a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py +++ b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py @@ -46,7 +46,7 @@ from torchvision.models import resnet18 import aimet_torch.v2.nn as aimet_nn -from aimet_torch.v2.nn import FakeQuantizationMixin +from aimet_torch.v2.nn import QuantizationMixin from aimet_torch.v2.quantization.affine import QuantizeDequantize from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer from aimet_torch.nn.modules.custom import Add @@ -93,7 +93,7 @@ def test_onnx_export(self): model = DummyModel(in_channels=input_shape[1]) sim_model = copy.deepcopy(model) for name, module in sim_model.named_children(): - quantized_module = FakeQuantizationMixin.from_module(module) + quantized_module = QuantizationMixin.from_module(module) if name == "conv1": input_quantizer 
= QuantizeDequantize((), @@ -148,7 +148,7 @@ def test_export_to_onnx_direct(self): sim_model = copy.deepcopy(model) dummy_input = (torch.rand(1, 1, 28, 28), torch.rand(1, 1, 28, 28)) for name, module in sim_model.named_children(): - quantized_module = FakeQuantizationMixin.from_module(module) + quantized_module = QuantizationMixin.from_module(module) if name == "conv1_a": input_quantizer = QuantizeDequantize((), @@ -207,7 +207,7 @@ def test_encodings_propagation(self): pixel_shuffle = torch.nn.PixelShuffle(2) model = torch.nn.Sequential(pixel_shuffle) - quantized_pixel_shuffle = FakeQuantizationMixin.from_module(pixel_shuffle) + quantized_pixel_shuffle = QuantizationMixin.from_module(pixel_shuffle) quantized_pixel_shuffle.input_quantizers[0] = QuantizeDequantize((), bitwidth=8, symmetric=False, @@ -258,7 +258,7 @@ def test_multi_output_onnx_op(self): model = ModelWith5Output() dummy_input = torch.randn(1, 3, 224, 224) sim_model = copy.deepcopy(model) - sim_model.cust = FakeQuantizationMixin.from_module(sim_model.cust) + sim_model.cust = QuantizationMixin.from_module(sim_model.cust) sim_model.cust.input_quantizers[0] = QuantizeDequantize((), bitwidth=8, symmetric=False, @@ -292,7 +292,7 @@ def test_mapping_encoding_for_torch_module_with_multiple_onnx_ops(self): model = SoftMaxAvgPoolModel() sim_model = copy.deepcopy(model) - sim_model.sfmax = FakeQuantizationMixin.from_module(sim_model.sfmax) + sim_model.sfmax = QuantizationMixin.from_module(sim_model.sfmax) sim_model.sfmax.input_quantizers[0] = QuantizeDequantize((), bitwidth=8, symmetric=False, @@ -302,7 +302,7 @@ def test_mapping_encoding_for_torch_module_with_multiple_onnx_ops(self): symmetric=False, encoding_analyzer=MinMaxEncodingAnalyzer(())) - sim_model.avgpool = FakeQuantizationMixin.from_module(sim_model.avgpool) + sim_model.avgpool = QuantizationMixin.from_module(sim_model.avgpool) sim_model.avgpool.input_quantizers[0] = QuantizeDequantize((), bitwidth=8, symmetric=False, diff --git a/TrainingExtensions/torch/test/python/v2/models_/models_to_test.py b/TrainingExtensions/torch/test/python/v2/models_/models_to_test.py index 5cd6c26149..4bc511ab49 100644 --- a/TrainingExtensions/torch/test/python/v2/models_/models_to_test.py +++ b/TrainingExtensions/torch/test/python/v2/models_/models_to_test.py @@ -40,7 +40,7 @@ from torch import nn import aimet_torch.nn.modules.custom as aimet_modules -from aimet_torch.v2.nn import FakeQuantizationMixin +from aimet_torch.v2.nn import QuantizationMixin class SimpleConditional(torch.nn.Module): @@ -398,8 +398,8 @@ def forward(self, *inputs): return x -@FakeQuantizationMixin.implements(ModuleWith5Output) -class FakeQuantizationModuleWith5Output(FakeQuantizationMixin, ModuleWith5Output): +@QuantizationMixin.implements(ModuleWith5Output) +class QuantizationModuleWith5Output(QuantizationMixin, ModuleWith5Output): def __quant_init__(self): super().__quant_init__() self.output_quantizers = torch.nn.ModuleList([None, None, None, None, None]) diff --git a/TrainingExtensions/torch/test/python/v2/models_/test_models.py b/TrainingExtensions/torch/test/python/v2/models_/test_models.py index dab734893b..2a54a4d014 100644 --- a/TrainingExtensions/torch/test/python/v2/models_/test_models.py +++ b/TrainingExtensions/torch/test/python/v2/models_/test_models.py @@ -999,7 +999,7 @@ def forward(self, features, rois): sampling_ratio = 0) -from aimet_torch.v2.nn.fake_quant import _FakeQuantizedUnaryOpMixin +from aimet_torch.v2.nn.fake_quant._legacy_impl import _FakeQuantizedUnaryOpMixin FakeQuantizedRoiAlignPyTorch = 
_FakeQuantizedUnaryOpMixin.wrap(RoiAlignPyTorch) diff --git a/TrainingExtensions/torch/test/python/v2/nn/deprecated/__init__.py b/TrainingExtensions/torch/test/python/v2/nn/deprecated/__init__.py new file mode 100644 index 0000000000..e2bd9499cb --- /dev/null +++ b/TrainingExtensions/torch/test/python/v2/nn/deprecated/__init__.py @@ -0,0 +1,36 @@ +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
+# +# SPDX-License-Identifier: BSD-3-Clause +# +# @@-COPYRIGHT-END-@@ +# ============================================================================= diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_activation.py b/TrainingExtensions/torch/test/python/v2/nn/deprecated/test_activation.py similarity index 98% rename from TrainingExtensions/torch/test/python/v2/nn/test_activation.py rename to TrainingExtensions/torch/test/python/v2/nn/deprecated/test_activation.py index e5993363a6..cf91fd740f 100644 --- a/TrainingExtensions/torch/test/python/v2/nn/test_activation.py +++ b/TrainingExtensions/torch/test/python/v2/nn/deprecated/test_activation.py @@ -40,7 +40,7 @@ import torch.nn.functional as F from aimet_torch.v2.quantization.affine.backends import quantize_dequantize from aimet_torch.v2.quantization.affine import QuantizeDequantize -from aimet_torch.v2.nn import FakeQuantizedSoftmax, FakeQuantizedReshape +from aimet_torch.v2.nn.fake_quant._legacy_impl import FakeQuantizedSoftmax, FakeQuantizedReshape from aimet_torch.v2.quantization.affine.encoding import AffineEncoding from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer from aimet_torch.v2.quantization.tensor import DequantizedTensor diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_linear.py b/TrainingExtensions/torch/test/python/v2/nn/deprecated/test_linear.py similarity index 99% rename from TrainingExtensions/torch/test/python/v2/nn/test_linear.py rename to TrainingExtensions/torch/test/python/v2/nn/deprecated/test_linear.py index 7cc9570b14..1b36308e0e 100644 --- a/TrainingExtensions/torch/test/python/v2/nn/test_linear.py +++ b/TrainingExtensions/torch/test/python/v2/nn/deprecated/test_linear.py @@ -41,7 +41,7 @@ import torch.nn.functional as F from aimet_torch.v2.quantization.affine.backends import quantize_dequantize from aimet_torch.v2.quantization.affine import QuantizeDequantize -from aimet_torch.v2.nn import FakeQuantizedLinear, FakeQuantizationMixin +from aimet_torch.v2.nn.fake_quant._legacy_impl import FakeQuantizedLinear, FakeQuantizationMixin from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_custom_op.py b/TrainingExtensions/torch/test/python/v2/nn/test_custom_op.py index 9620ee291b..0c9811327b 100644 --- a/TrainingExtensions/torch/test/python/v2/nn/test_custom_op.py +++ b/TrainingExtensions/torch/test/python/v2/nn/test_custom_op.py @@ -37,7 +37,7 @@ import pytest import torch -from aimet_torch.v2.nn import FakeQuantizationMixin +from aimet_torch.v2.nn import QuantizationMixin class CustomOp(torch.nn.Module): @@ -46,43 +46,43 @@ def forward(self, input): return input * 2 + 1 -class TestFakeQuantizedCustomOp: +class TestQuantizedCustomOp: def test_custom_op_from_module_unregistered(self): with pytest.raises(RuntimeError): - _ = FakeQuantizationMixin.from_module(CustomOp()) + _ = QuantizationMixin.from_module(CustomOp()) def test_custom_op_from_module_registered(self): try: - @FakeQuantizationMixin.implements(CustomOp) - class FakeQuantizedCustomOp(FakeQuantizationMixin, CustomOp): + @QuantizationMixin.implements(CustomOp) + class QuantizedCustomOp(QuantizationMixin, CustomOp): def quantized_forward(self, x): x = super().forward(x) return self.output_quantizers[0](x) - quantized_custom_op = FakeQuantizationMixin.from_module(CustomOp()) - assert isinstance(quantized_custom_op, FakeQuantizedCustomOp) + quantized_custom_op = QuantizationMixin.from_module(CustomOp()) + assert 
isinstance(quantized_custom_op, QuantizedCustomOp) - quantized_custom_op_ = FakeQuantizationMixin.from_module(CustomOp()) - assert isinstance(quantized_custom_op_, FakeQuantizedCustomOp) + quantized_custom_op_ = QuantizationMixin.from_module(CustomOp()) + assert isinstance(quantized_custom_op_, QuantizedCustomOp) finally: # Unregister CustomOp so as not to affect other test functions - FakeQuantizationMixin.cls_to_qcls.pop(CustomOp) + QuantizationMixin.cls_to_qcls.pop(CustomOp) def test_custom_op_wrap_registered(self): try: - @FakeQuantizationMixin.implements(CustomOp) - class FakeQuantizedCustomOp(FakeQuantizationMixin, CustomOp): + @QuantizationMixin.implements(CustomOp) + class QuantizedCustomOp(QuantizationMixin, CustomOp): def quantized_forward(self, x): x = super().forward(x) return self.output_quantizers[0](x) - quantized_custom_op_cls = FakeQuantizationMixin.wrap(CustomOp) - assert quantized_custom_op_cls is FakeQuantizedCustomOp + quantized_custom_op_cls = QuantizationMixin.wrap(CustomOp) + assert quantized_custom_op_cls is QuantizedCustomOp - quantized_custom_op_cls_ = FakeQuantizationMixin.wrap(CustomOp) - assert quantized_custom_op_cls_ is FakeQuantizedCustomOp + quantized_custom_op_cls_ = QuantizationMixin.wrap(CustomOp) + assert quantized_custom_op_cls_ is QuantizedCustomOp finally: # Unregister CustomOp so as not to affect other test functions - FakeQuantizationMixin.cls_to_qcls.pop(CustomOp) + QuantizationMixin.cls_to_qcls.pop(CustomOp) diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py index bf82fad414..d94ce1a3f7 100644 --- a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py +++ b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py @@ -63,8 +63,8 @@ QuantizedSoftmax, QuantizedLayerNorm, QuantizedGroupNorm, - FakeQuantizationMixin, ) +from aimet_torch.v2.nn.fake_quant import _legacy_impl from aimet_torch.v2.nn.true_quant import _dispatch, _dispatch_table from aimet_torch.v2.quantization.affine import AffineEncoding from aimet_torch.v2.quantization.tensor import QuantizedTensorBase, QuantizedTensor, DequantizedTensor @@ -436,7 +436,7 @@ def test_layers_no_params(self, module_factory, input_factory): if not isinstance(inputs, (tuple, list)): inputs = (inputs,) - fq_layer = FakeQuantizationMixin.from_module(layer) + fq_layer = _legacy_impl.FakeQuantizationMixin.from_module(layer) tq_layer = QuantizationMixin.from_module(layer) for i, _ in enumerate(inputs): fq_layer.input_quantizers[i] = QuantizeDequantize(shape=(), bitwidth=8, symmetric=False) @@ -474,7 +474,7 @@ def test_layers_with_weight(self, module_factory, input_factory): layer = module_factory() input = input_factory() - fq_layer = FakeQuantizationMixin.from_module(layer) + fq_layer = _legacy_impl.FakeQuantizationMixin.from_module(layer) tq_layer = QuantizationMixin.from_module(layer) fq_layer.input_quantizers[0] = QuantizeDequantize(shape=(), bitwidth=8, symmetric=False) fq_layer.output_quantizers[0] = QuantizeDequantize(shape=(), bitwidth=8, symmetric=False) @@ -680,7 +680,7 @@ def test_dispatch_sanity(): def _create_legacy_fake_quantized_module(module): - qmodule = aimet.nn.fake_quant.FakeQuantizationMixin.from_module(module) + qmodule = _legacy_impl.FakeQuantizationMixin.from_module(module) for i, _ in enumerate(qmodule.input_quantizers): qmodule.input_quantizers[i] = QuantizeDequantize([], 8, False) diff --git a/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py 
b/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py index 18f9640d85..2fc801d812 100644 --- a/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py +++ b/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py @@ -47,7 +47,7 @@ from aimet_torch.utils import create_fake_data_loader from aimet_torch.v2.quantsim import QuantizationSimModel -from aimet_torch.v2.nn.fake_quant import FakeQuantizationMixin +from aimet_torch.v2.nn import QuantizationMixin from aimet_torch.v2.quantization.affine import QuantizeDequantize from aimet_torch.v2.seq_mse import apply_seq_mse, get_candidates, optimize_module, SeqMseParams, SequentialMse from .models_.mnist_torch_model import Net @@ -160,7 +160,7 @@ def test_optimize_module_linear(self, enable_pcq, param_bw, loss_fn, qparam_requ """ test optimize module for linear """ torch.manual_seed(0) linear = torch.nn.Linear(64, 128) - wrapper = FakeQuantizationMixin.from_module(linear) + wrapper = QuantizationMixin.from_module(linear) if enable_pcq: quantizer_shape = [linear.weight.shape[0], 1] else: @@ -194,7 +194,7 @@ def test_optimize_module_conv(self, enable_pcq, param_bw, loss_fn): """ test optimize module for linear """ torch.manual_seed(0) conv = torch.nn.Conv2d(3, 32, 3) - wrapper = FakeQuantizationMixin.from_module(conv) + wrapper = QuantizationMixin.from_module(conv) if enable_pcq: quantizer_shape = [conv.weight.shape[0], 1, 1, 1] else:
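
Usage sketch (editor's illustration, not part of the patch): after this change, the legacy FakeQuantized- names exported from aimet_torch.v2.nn.fake_quant are plain aliases of the Quantized- classes, the original fake-quant implementation remains importable from fake_quant._legacy_impl, and defining `quantized_forward` on a QuantizationMixin subclass now triggers a DeprecationWarning via QuantizationMixinMeta. The snippet assumes an AIMET v2 build containing this commit and that QuantizedLinear is re-exported from aimet_torch.v2.nn (as the test imports above suggest).

# Illustration only -- assumes an AIMET v2 build that includes this patch.
import warnings
import torch

from aimet_torch.v2.nn import QuantizationMixin, QuantizedLinear
from aimet_torch.v2.nn.fake_quant import FakeQuantizationMixin, FakeQuantizedLinear
from aimet_torch.v2.nn.fake_quant import _legacy_impl

# 1) Legacy FakeQuantized- names now resolve to the Quantized- classes.
assert FakeQuantizationMixin is QuantizationMixin
assert FakeQuantizedLinear is QuantizedLinear

# 2) The previous fake-quant implementation is still reachable (internal API).
legacy_linear = _legacy_impl.FakeQuantizationMixin.from_module(torch.nn.Linear(8, 8))

# 3) Defining `quantized_forward` instead of `forward` is deprecated:
#    QuantizationMixinMeta aliases it to `forward` and emits a DeprecationWarning
#    at class-definition time.
class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    @QuantizationMixin.implements(Scale)
    class QuantizedScale(QuantizationMixin, Scale):
        def quantized_forward(self, x):  # deprecated spelling of `forward`
            x = super().forward(x)
            return self.output_quantizers[0](x) if self.output_quantizers[0] else x

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
QuantizationMixin.cls_to_qcls.pop(Scale)  # clean up the registration, as the tests do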