From f0b3b0f9030deacb1fc8787cf33e06cf20f5f2bb Mon Sep 17 00:00:00 2001
From: Kevin Hsieh
Date: Thu, 3 Oct 2024 20:55:15 -0700
Subject: [PATCH] Update RMSNorm op definition

Signed-off-by: Kevin Hsieh
---
 .../python/aimet_torch/nn/modules/custom.py   |  4 +++-
 .../aimet_torch/v2/nn/modules/custom.py       | 20 +++++++++++++++++++
 .../test/python/v2/nn/test_true_quant.py      |  1 +
 3 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/nn/modules/custom.py
index 1e34d844dea..c12d3f2108d 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/nn/modules/custom.py
@@ -933,7 +933,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for RmsNorm
         """
+        input_dtype = x.dtype
+        x = x.to(dtype=torch.float32, copy=True)
         squared_mean = torch.mean(x * x, dim=self.axes, keepdim=True)
         rms = torch.sqrt(squared_mean + self.epsilon)
-        res = torch.div(x, rms) * self.weight + self.bias
+        res = (torch.div(x, rms) * self.weight + self.bias).to(dtype=input_dtype)
         return res
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
index db1d8196e48..7505d9b09b5 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
@@ -486,6 +486,26 @@ class QuantizedAddmm(_DispatchMixin, QuantizationMixin, Addmm):
     _builtin_torch_fn = torch.addmm


+@QuantizationMixin.implements(RmsNorm)
+class QuantizedRmsNorm(QuantizationMixin, RmsNorm):
+    """Custom module for RmsNorm"""
+    # pylint: disable=arguments-differ
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for RmsNorm
+        """
+        if self.input_quantizers[0]:
+            x = self.input_quantizers[0](x)
+
+        with self._patch_quantized_parameters():
+            out = super().forward(x)
+
+        if self.output_quantizers[0]:
+            out = self.output_quantizers[0](out)
+
+        return out
+
+
 # @QuantizationMixin.implements(Square)
 # class QuantizedSquare(_DispatchMixin, QuantizationMixin, Square):
 #     """ Quantized Square """
diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
index c3088ad342e..a099dad5005 100644
--- a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
+++ b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
@@ -954,6 +954,7 @@ def _create_quantized_module(module):
     # (lambda: custom.Normalize(), lambda: ...),
     # (lambda: custom.Pad(), lambda: ...),
     # (lambda: custom.GridSample(), lambda: ...),
+    (lambda: custom.RmsNorm([5, 2, 3], [2], 1e-5), lambda: (randn(5, 2, 3)))
 ]))
 def test_default_kernels(module_factory, input_factory):
     module = module_factory()
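
Illustrative sketch (not part of the patch): a minimal standalone reproduction of what the updated RmsNorm forward computes, assuming elementwise weight and bias parameters matching the normalized shape. The function name rmsnorm_reference and the explicit weight/bias tensors are assumptions made for illustration; only the arithmetic mirrors the first hunk, and the shapes follow the RmsNorm([5, 2, 3], [2], 1e-5) entry added to the test.

    import torch

    def rmsnorm_reference(x, weight, bias, axes, epsilon=1e-5):
        # Mirrors the patched forward: upcast to float32 for the mean/sqrt,
        # normalize, apply the affine parameters, then cast back to the
        # caller's original dtype.
        input_dtype = x.dtype
        x = x.to(dtype=torch.float32, copy=True)
        squared_mean = torch.mean(x * x, dim=axes, keepdim=True)
        rms = torch.sqrt(squared_mean + epsilon)
        return (torch.div(x, rms) * weight + bias).to(dtype=input_dtype)

    # Shapes follow the new test entry RmsNorm([5, 2, 3], [2], 1e-5):
    x = torch.randn(5, 2, 3, dtype=torch.float16)
    weight = torch.ones(5, 2, 3)   # assumed elementwise affine parameters
    bias = torch.zeros(5, 2, 3)
    out = rmsnorm_reference(x, weight, bias, axes=[2])
    assert out.dtype == torch.float16   # output dtype matches the input dtype

This only illustrates the dtype-preservation behavior introduced by the first hunk; the QuantizedRmsNorm class added in the second hunk wraps the same computation with the module's input and output quantizer calls.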