Commit 99b0e62
Merge pull request apache#50 from sunggg/bugfix/2024-Feb/fix-cutlass-rms-norm

[CUTLASS] Use operator pattern for RMS Norm to conform with SLM model definition
sunggg committed Feb 28, 2024
2 parents: 4d56c71 + 2914c0f
Showing 4 changed files with 30 additions and 29 deletions.
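
The gist of the change: the CUTLASS BYOC pattern previously matched RMS norm as a `relax.call_tir` into a generated PrimFunc and relied on a name-based check (`_check_rms_norm`) to confirm the callee was an RMS norm kernel. SLM model definitions emit the high-level `relax.nn.rms_norm` operator instead, so the pattern now matches that operator directly and the name check is dropped. A minimal sketch contrasting the two matching styles (imports per TVM's relax pattern DSL; illustrative, not part of the diff itself):

# Assumes a TVM build with relax; mirrors the pattern code changed below.
from tvm.relax.dpl import GlobalVarPattern, TuplePattern, is_op, wildcard

# Old: match any call_tir taking two arguments, then filter by callee name hint.
gv = GlobalVarPattern()
old_pattern = is_op("relax.call_tir")(gv, TuplePattern([wildcard(), wildcard()]))

# New: match the high-level operator directly; no name-based check required.
new_pattern = is_op("relax.nn.rms_norm")(wildcard(), wildcard())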
python/tvm/relax/backend/contrib/cutlass.py (0 additions, 9 deletions)

@@ -464,21 +464,12 @@ def layer_norm_pattern():
     ]
 
 
-def _check_rms_norm(ctx: PatternCheckContext) -> bool:
-    rms_norm = ctx.annotated_expr["rms_norm"]
-    if "rms_norm" not in rms_norm.args[0].name_hint:
-        return False
-
-    return True
-
-
 def rms_norm_pattern():
     """Create a RMS norm pattern for CUTLASS."""
     return [
         (
             "cutlass.rms_norm",
             *make_rms_norm_pattern(),
-            _check_rms_norm,
         ),
     ]

python/tvm/relax/backend/patterns.py (2 additions, 4 deletions)

@@ -330,10 +330,8 @@ def make_rms_norm_pattern():
     """Create a layer norm pattern."""
     inp = wildcard()
     weight = wildcard()
-    gv = GlobalVarPattern()
-    out = is_op("relax.call_tir")(gv, TuplePattern([inp, weight]))
-    annotations = {"gv": gv, "inp": inp, "rms_norm": out}
-    return out, annotations
+    out = is_op("relax.nn.rms_norm")(inp, weight)
+    return out, {}
 
 
 def make_attention_rewrite_pattern(

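For context, tuples like `("cutlass.rms_norm", *make_rms_norm_pattern())` are consumed by the generic FuseOpsByPattern pass; `partition_for_cutlass` applies the full CUTLASS pattern table this way. A sketch of driving the pass with just this one pattern (assumptions: a TVM build with relax; the module mirrors the new test below):

import tvm
from tvm import relax
from tvm.relax.backend.patterns import make_rms_norm_pattern
from tvm.script import relax as R

@tvm.script.ir_module
class Mod:
    @R.function
    def main(data: R.Tensor((1, 64, 3), "float16"), weight: R.Tensor((3,), "float16")):
        with R.dataflow():
            out = R.nn.rms_norm(data, weight)
            R.output(out)
        return out

# One (name, pattern, annotations) entry, the same tuple shape used in cutlass.py.
fused = relax.transform.FuseOpsByPattern(
    [("cutlass.rms_norm", *make_rms_norm_pattern())]
)(Mod)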
python/tvm/relax/op/nn/nn.py (1 addition, 4 deletions)

@@ -1061,7 +1061,7 @@ def rms_norm(
     .. math::
-        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias
+        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight
 
     Parameters
     ----------
@@ -1071,9 +1071,6 @@ def rms_norm(
     weight : relax.Expr
         The scale factor.
 
-    bias : relax.Expr
-        The offset factor.
-
     axes : Union[int, List[int]]
         The axes that along which the normalization is applied.

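Side note on the formula: RMS norm as standardly defined (Zhang and Sennrich, 2019) divides by the root of the mean of the squared values along the reduction axes, and after this change there is no bias term. A NumPy reference sketch of that computation (illustrative only, not TVM code; `rms_norm_ref` is a hypothetical name):

import numpy as np

def rms_norm_ref(data, weight, axis=-1, epsilon=1e-5):
    # Root mean square over the normalization axis, then scale by weight.
    rms = np.sqrt(np.mean(np.square(data.astype("float32")), axis=axis, keepdims=True) + epsilon)
    return (data.astype("float32") / rms * weight.astype("float32")).astype(data.dtype)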
tests/python/relax/test_transform_fuse_ops_by_pattern.py (27 additions, 12 deletions)

@@ -56,9 +56,9 @@ def main(
     ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
         cls = Conv2dReLU_composite_annotated
         with R.dataflow():
-            gv: R.Tensor(
-                (1, 64, 56, 56), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
+            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
+            )
             R.output(gv)
         return gv

@@ -120,12 +120,12 @@ def main(
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
         cls = Conv2dReLUx2Partitioned
         with R.dataflow():
-            lv: R.Tensor(
-                (1, 64, 56, 56), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1)
-            gv: R.Tensor(
-                (1, 64, 54, 54), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2)
+            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1)
+            )
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2)
+            )
             R.output(gv)
         return gv

@@ -235,9 +235,9 @@ def main(
             lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d(
                 data, weight1
             )
-            gv: R.Tensor(
-                (1, 64, 54, 54), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2)
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = (
+                cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2)
+            )
             R.output(gv)
         return gv

@@ -1024,6 +1024,21 @@ def main(
     assert "fused_relax_matmul_relax_add_relax_add_cutlass" in func_names
 
 
+def test_rms_norm():
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(data: R.Tensor((1, 64, 3), "float16"), weight: R.Tensor((3,), "float16")):
+            with R.dataflow():
+                out = R.nn.rms_norm(data, weight)
+                R.output(out)
+            return out
+
+    mod = partition_for_cutlass(RMSNorm)
+    func_names = [name.name_hint for (name, _) in mod.functions.items()]
+    assert "fused_relax_nn_rms_norm_cutlass" in func_names
+
+
 def test_intermediate_var_to_var_binding():
     """test the intermediate binding y1 will break the fusion"""

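Usage note (beyond the scope of this commit): after `partition_for_cutlass`, the fused `fused_relax_nn_rms_norm_cutlass` function would typically be handed to the CUTLASS codegen and built; a sketch, assuming a TVM build with CUDA and CUTLASS enabled (codegen options omitted):

from tvm import relax
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

mod = partition_for_cutlass(RMSNorm)      # RMSNorm as defined in the test above
mod = relax.transform.RunCodegen()(mod)   # offload fused functions to the CUTLASS backend
ex = relax.build(mod, target="cuda")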
