diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index de7eb54b991b..6f47b44d1766 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -464,21 +464,12 @@ def layer_norm_pattern():
     ]
 
 
-def _check_rms_norm(ctx: PatternCheckContext) -> bool:
-    rms_norm = ctx.annotated_expr["rms_norm"]
-    if "rms_norm" not in rms_norm.args[0].name_hint:
-        return False
-
-    return True
-
-
 def rms_norm_pattern():
     """Create a RMS norm pattern for CUTLASS."""
     return [
         (
             "cutlass.rms_norm",
             *make_rms_norm_pattern(),
-            _check_rms_norm,
         ),
     ]
 
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 23de175b24f6..e7bf3d501211 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -330,10 +330,8 @@ def make_rms_norm_pattern():
     """Create a layer norm pattern."""
     inp = wildcard()
     weight = wildcard()
-    gv = GlobalVarPattern()
-    out = is_op("relax.call_tir")(gv, TuplePattern([inp, weight]))
-    annotations = {"gv": gv, "inp": inp, "rms_norm": out}
-    return out, annotations
+    out = is_op("relax.nn.rms_norm")(inp, weight)
+    return out, {}
 
 
 def make_attention_rewrite_pattern(
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5adf38d7d642..71fc8661dec4 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1061,7 +1061,7 @@ def rms_norm(
 
     .. math::
 
-        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias
+        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight
 
     Parameters
     ----------
@@ -1071,9 +1071,6 @@ def rms_norm(
     weight : relax.Expr
         The scale factor.
 
-    bias : relax.Expr
-        The offset factor.
-
     axes : Union[int, List[int]]
         The axes that along which the normalization is applied.
 
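Note for reviewers (illustrative, not part of the patch): with the `call_tir`-based check removed, `make_rms_norm_pattern` matches the high-level `relax.nn.rms_norm` op directly, so no check function is needed when registering the pattern and the old name-based heuristic on the TIR PrimFunc's `name_hint` goes away. A minimal sketch of exercising the simplified pattern through `FuseOpsByPattern`; the `Module` below is an assumption that mirrors the new test, not code from this patch:

    import tvm
    from tvm import relax
    from tvm.relax.backend.patterns import make_rms_norm_pattern
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(data: R.Tensor((1, 64, 3), "float16"), weight: R.Tensor((3,), "float16")):
            with R.dataflow():
                out = R.nn.rms_norm(data, weight)
                R.output(out)
            return out

    # make_rms_norm_pattern() now returns (pattern, annotations) with empty
    # annotations, so the tuple can be registered without a check function.
    pattern, annotations = make_rms_norm_pattern()
    mod = relax.transform.FuseOpsByPattern([("cutlass.rms_norm", pattern, annotations)])(Module)
    print(mod)  # main now calls a fused function carrying Composite="cutlass.rms_norm"

Matching on the high-level op also keeps the pattern independent of how `rms_norm` happens to be lowered, which is why the partition must now run before legalization.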
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index bd434864a081..e093e6022e0a 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -56,9 +56,9 @@ def main(
     ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
         cls = Conv2dReLU_composite_annotated
         with R.dataflow():
-            gv: R.Tensor(
-                (1, 64, 56, 56), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
+            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
+            )
             R.output(gv)
         return gv
 
@@ -120,12 +120,12 @@ def main(
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
         cls = Conv2dReLUx2Partitioned
         with R.dataflow():
-            lv: R.Tensor(
-                (1, 64, 56, 56), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1)
-            gv: R.Tensor(
-                (1, 64, 54, 54), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2)
+            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1)
+            )
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2)
+            )
             R.output(gv)
         return gv
 
@@ -235,9 +235,9 @@ def main(
             lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d(
                 data, weight1
             )
-            gv: R.Tensor(
-                (1, 64, 54, 54), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2)
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2)
+            )
             R.output(gv)
         return gv
 
@@ -1024,6 +1024,21 @@ def main(
     assert "fused_relax_matmul_relax_add_relax_add_cutlass" in func_names
 
 
+def test_rms_norm():
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(data: R.Tensor((1, 64, 3), "float16"), weight: R.Tensor((3,), "float16")):
+            with R.dataflow():
+                out = R.nn.rms_norm(data, weight)
+                R.output(out)
+            return out
+
+    mod = partition_for_cutlass(RMSNorm)
+    func_names = [name.name_hint for (name, _) in mod.functions.items()]
+    assert "fused_relax_nn_rms_norm_cutlass" in func_names
+
+
 def test_intermediate_var_to_var_binding():
     """test the intermediate binding y1 will break the fusion"""
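For context, a module partitioned this way could be compiled and run end to end roughly as follows. This is a sketch, not part of the patch, assuming a CUDA build of TVM with CUTLASS enabled; the `sm` option value is illustrative and `RMSNorm` refers to the module defined in `test_rms_norm` above:

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    mod = partition_for_cutlass(RMSNorm)  # RMSNorm as in test_rms_norm above
    # Generate CUTLASS code for the partitioned composite functions;
    # sm=80 is an illustrative choice of target architecture.
    mod = relax.transform.RunCodegen({"cutlass": {"sm": 80}})(mod)
    ex = relax.build(mod, target="cuda")
    vm = relax.VirtualMachine(ex, tvm.cuda())

    data = tvm.nd.array(np.random.uniform(size=(1, 64, 3)).astype("float16"), tvm.cuda())
    weight = tvm.nd.array(np.ones((3,), dtype="float16"), tvm.cuda())
    out = vm["main"](data, weight)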