Skip to content

Commit

Permalink
Update RMSNorm op definition
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Hsieh <[email protected]>
  • Loading branch information
quic-klhsieh authored Oct 4, 2024
1 parent 8b1310b commit f0b3b0f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for RmsNorm
"""
input_dtype = x.dtype
x = x.to(dtype=torch.float32, copy=True)
squared_mean = torch.mean(x * x, dim=self.axes, keepdim=True)
rms = torch.sqrt(squared_mean + self.epsilon)
res = torch.div(x, rms) * self.weight + self.bias
res = (torch.div(x, rms) * self.weight + self.bias).to(dtype=input_dtype)
return res
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,26 @@ class QuantizedAddmm(_DispatchMixin, QuantizationMixin, Addmm):
_builtin_torch_fn = torch.addmm


@QuantizationMixin.implements(RmsNorm)
class QuantizedRmsNorm(QuantizationMixin, RmsNorm):
"""Custom module for RmsNorm"""
# pylint: disable=arguments-differ
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for RmsNorm
"""
if self.input_quantizers[0]:
x = self.input_quantizers[0](x)

with self._patch_quantized_parameters():
out = super().forward(x)

if self.output_quantizers[0]:
out = self.output_quantizers[0](out)

return out


# @QuantizationMixin.implements(Square)
# class QuantizedSquare(_DispatchMixin, QuantizationMixin, Square):
# """ Quantized Square """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,7 @@ def _create_quantized_module(module):
# (lambda: custom.Normalize(), lambda: ...),
# (lambda: custom.Pad(), lambda: ...),
# (lambda: custom.GridSample(), lambda: ...),
(lambda: custom.RmsNorm([5, 2, 3], [2], 1e-5), lambda: (randn(5, 2, 3)))
]))
def test_default_kernels(module_factory, input_factory):
module = module_factory()
Expand Down

0 comments on commit f0b3b0f

Please sign in to comment.