Skip to content

Commit

Permalink
Add support for Relax RMS norm in cutlass offloading
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 13, 2024
1 parent 2a3007d commit 1ae8904
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
16 changes: 14 additions & 2 deletions python/tvm/relax/backend/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
make_matmul_pattern,
make_residual_block_pattern,
make_rms_norm_pattern,
make_rms_norm_tir_pattern,
make_stacked_attention_pattern,
)
from ..utils import has_leaking_intermediate_variables
Expand Down Expand Up @@ -469,7 +470,7 @@ def layer_norm_pattern():
]


def _check_rms_norm(ctx: PatternCheckContext) -> bool:
def _check_rms_norm_tir(ctx: PatternCheckContext) -> bool:
rms_norm = ctx.annotated_expr["rms_norm"]
if "rms_norm" not in rms_norm.args[0].name_hint:
return False
Expand All @@ -483,7 +484,17 @@ def rms_norm_pattern():
(
"cutlass.rms_norm",
*make_rms_norm_pattern(),
_check_rms_norm,
),
]


def rms_norm_tir_pattern():
    """Return the CUTLASS pattern-table entry matching a TIR-based RMS norm."""
    # Each entry is (name, pattern, annotations, check); the pattern helper
    # supplies the (pattern, annotations) pair.
    pattern, annotations = make_rms_norm_tir_pattern()
    entry = ("cutlass.rms_norm_tir", pattern, annotations, _check_rms_norm_tir)
    return [entry]

Expand Down Expand Up @@ -512,6 +523,7 @@ def attention_rewrite_patterns():
*attention_patterns(),
*layer_norm_pattern(),
*rms_norm_pattern(),
*rms_norm_tir_pattern(),
]
)

Expand Down
10 changes: 9 additions & 1 deletion python/tvm/relax/backend/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def make_layer_norm_pattern():
return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}


def make_rms_norm_pattern():
def make_rms_norm_tir_pattern():
"""Create a layer norm pattern."""
inp = wildcard()
weight = wildcard()
Expand All @@ -336,6 +336,14 @@ def make_rms_norm_pattern():
return out, annotations


def make_rms_norm_pattern():
    """Create a pattern for the Relax RMS norm op.

    The original docstring said "layer norm", a copy-paste left-over from
    ``make_layer_norm_pattern``; this pattern actually matches
    ``relax.nn.rms_norm``.

    Returns
    -------
    pattern: DFPattern
        The resulting pattern describing ``relax.nn.rms_norm``.

    annotations: Mapping[str, DFPattern]
        A mapping from name to sub pattern. Empty here, since no
        intermediate expressions need to be annotated.
    """
    # Match rms_norm applied to any input and weight expression.
    inp = wildcard()
    weight = wildcard()
    out = is_op("relax.nn.rms_norm")(inp, weight)
    return out, {}


def make_attention_rewrite_pattern(
qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
):
Expand Down

0 comments on commit 1ae8904

Please sign in to comment.