Weight name preservation for CrossAttention
Summary:
1. Make sure lowering_utils captures all constants, including those of submodules.
2. Add a name to fx2ait's conversion of Attention modules.

Differential Revision: D50060817

fbshipit-source-id: a8ba5b06ffc869370f98ff5b7efdc7f1e905e492
muchulee8 authored and facebook-github-bot committed Oct 7, 2023
1 parent 74d813f commit 1030bdf
Showing 5 changed files with 53 additions and 40 deletions.
11 changes: 3 additions & 8 deletions fx2ait/fx2ait/converters/ait_module_converters.py
@@ -59,6 +59,7 @@ def multi_head_attention_module(
         num_heads=submod.num_heads,
         qkv_bias=True,
         has_residual=False,
+        name=target,
     )

     # Bind constant tensor for MHA module
@@ -95,13 +96,7 @@ def _map_ait_pt_params(ait_module, pt_module):
         elif "in_proj" in pt_name:
             # set constant for cross attention
             if len(pt_param.shape) == 2:
-                w_q, w_k, w_v = pt_param.chunk(3)
-                mapped_pt_params["proj_q.weight"] = w_q
-                mapped_pt_params["proj_k.weight"] = w_k
-                mapped_pt_params["proj_v.weight"] = w_v
+                mapped_pt_params["proj_qkv_weight"] = pt_param.data
             else:
-                b_q, b_k, b_v = pt_param.chunk(3)
-                mapped_pt_params["proj_q.bias"] = b_q
-                mapped_pt_params["proj_k.bias"] = b_k
-                mapped_pt_params["proj_v.bias"] = b_v
+                mapped_pt_params["proj_qkv_bias"] = pt_param.data
     return mapped_pt_params
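
This keeps the PT-to-AIT constant mapping aligned with the new packed QKV parameters in CrossAttention (see attention.py below). A minimal PyTorch-only sketch, with arbitrary example sizes and no AITemplate required, of why binding the packed in_proj tensor is equivalent to the old per-projection split:

import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, bias=True)

# Packed projection parameters: [3 * embed_dim, embed_dim] and [3 * embed_dim].
print(mha.in_proj_weight.shape, mha.in_proj_bias.shape)

# Old mapping: chunk into q/k/v pieces and bind each piece separately.
w_q, w_k, w_v = mha.in_proj_weight.chunk(3)

# New mapping: bind the packed tensor once; it is exactly the concatenation
# of the per-projection pieces, so no information is lost.
assert torch.equal(torch.cat([w_q, w_k, w_v]), mha.in_proj_weight.data)
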
3 changes: 2 additions & 1 deletion fx2ait/fx2ait/fx2ait.py
@@ -26,6 +26,7 @@

 # @manual=//aitemplate/AITemplate/python/aitemplate:aitemplate
 from aitemplate.testing import detect_target
+from aitemplate.utils.misc import make_str_ait_friendly
 from fx2ait.ait_module import ARG_SPLITTER_KEYWORD
 from .converters.ait_converters import *  # isort:skip # noqa: F401 F403
 from .converters.aten2ait_converters import *  # isort:skip # noqa: F401 F403
@@ -39,7 +40,7 @@
 from .converters.converter_registry import AIT_CONVERTERS
 from .tensor_spec import TensorSpec

-from .utils import dtype_to_str, make_str_ait_friendly
+from .utils import dtype_to_str

 _LOGGER: logging.Logger = logging.getLogger(__name__)

10 changes: 0 additions & 10 deletions fx2ait/fx2ait/utils.py
@@ -19,13 +19,3 @@ def dtype_to_str(dtype):
     if dtype is None:
         return "float16"
     return torch_dtype_to_string(dtype)
-
-
-def make_str_ait_friendly(s: str) -> str:
-    if s.isalnum():
-        ret = s
-    else:
-        ret = "".join(c if c.isalnum() else "_" for c in s)
-    if ret[0].isdigit():
-        ret = "_" + ret
-    return ret
58 changes: 37 additions & 21 deletions python/aitemplate/frontend/nn/attention.py
@@ -24,6 +24,7 @@
 from aitemplate.frontend.nn.module import Module
 from aitemplate.frontend.nn.parameter import Parameter
 from aitemplate.testing import detect_target
+from aitemplate.utils.misc import make_str_ait_friendly


 class FlashAttention(Module):
@@ -324,6 +325,7 @@ def __init__(
         has_residual=True,
         causal=False,
         dtype="float16",
+        name=None,
     ):
         super().__init__()
         assert (
@@ -333,41 +335,55 @@ def __init__(
         self.causal = causal
         self.has_residual = has_residual
         self.dim = dim
+        self.has_bias = qkv_bias

         self.op = ops.mem_eff_attention(causal=causal)

-        self.proj_q = Linear(
-            dim,
-            dim,
-            bias=qkv_bias,
-            dtype=dtype,
-        )
-        self.proj_k = Linear(
-            dim,
-            dim,
-            bias=qkv_bias,
-            dtype=dtype,
-        )
-        self.proj_v = Linear(
-            dim,
-            dim,
-            bias=qkv_bias,
-            dtype=dtype,
+        param_name = name
+        if name is not None:
+            param_name = make_str_ait_friendly(name + ".in_proj_weight")
+        self.proj_qkv_weight = Parameter(
+            shape=[3 * dim, dim], dtype=dtype, name=param_name
         )
+        if self.has_bias:
+            if name is not None:
+                param_name = make_str_ait_friendly(name + ".in_proj_bias")
+            self.proj_qkv_bias = Parameter(
+                shape=[3 * dim], dtype=dtype, name=param_name
+            )

         self.attn_drop = Dropout(attn_drop, dtype=dtype)
         self.proj = Linear(
             dim, dim, specialization="add" if has_residual else None, dtype=dtype
         )
+        if name is not None:
+            self.proj.weight.tensor()._attrs["name"] = make_str_ait_friendly(
+                name + ".out_proj_weight"
+            )
+            self.proj.bias.tensor()._attrs["name"] = make_str_ait_friendly(
+                name + ".out_proj_bias"
+            )
         self.proj_drop = Dropout(proj_drop, dtype=dtype)

+    def linear(self, x, weight, bias):
+        USE_CUDA = detect_target().name() == "cuda"
+        if USE_CUDA:
+            x = ops.reshape()(x, [-1, self.dim])
+
+        if bias is None:
+            return ops.gemm_rcr()(x, weight)
+        return ops.gemm_rcr_bias()(x, weight, bias)
+
     def attention(self, q, k, v):
         batch = q.shape()[0]
         head_dim = self.dim // self.num_heads

-        query = self.proj_q(q)
-        key = self.proj_k(k)
-        value = self.proj_v(v)
+        q_w, k_w, v_w = ops.chunk()(self.proj_qkv_weight.tensor(), 3)
+        q_b, k_b, v_b = None, None, None
+        if self.has_bias:
+            q_b, k_b, v_b = ops.chunk()(self.proj_qkv_bias.tensor(), 3)
+        query = self.linear(q, q_w, q_b)
+        key = self.linear(k, k_w, k_b)
+        value = self.linear(v, v_w, v_b)

         query = ops.permute()(
             ops.reshape()(query, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
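
To make the refactor concrete, here is a PyTorch-only sketch of the numerics behind the new attention() path: one packed [3 * dim, dim] weight chunked into q/k/v pieces gives the same projections as the three separate Linear layers it replaces. The variable names mirror the diff; the AIT ops (ops.chunk, gemm_rcr_bias) are assumed to behave like torch.chunk and F.linear, and the sizes below are arbitrary illustration values.

import torch
import torch.nn.functional as F

dim = 16
x = torch.randn(4, dim)
proj_qkv_weight = torch.randn(3 * dim, dim)   # packed q/k/v projection weight
proj_qkv_bias = torch.randn(3 * dim)          # packed q/k/v projection bias

# ops.chunk() on the packed tensors corresponds to torch.chunk along dim 0.
q_w, k_w, v_w = proj_qkv_weight.chunk(3)
q_b, k_b, v_b = proj_qkv_bias.chunk(3)

# gemm_rcr_bias(x, w, b) is assumed to compute x @ w.T + b, i.e. F.linear.
query = F.linear(x, q_w, q_b)
key = F.linear(x, k_w, k_b)
value = F.linear(x, v_w, v_b)

# Running the packed projection once and slicing the output gives the same
# result, which is why a single constant can replace the three Linear weights.
packed = F.linear(x, proj_qkv_weight, proj_qkv_bias)
assert torch.allclose(query, packed[:, :dim], atol=1e-5)
assert torch.allclose(key, packed[:, dim:2 * dim], atol=1e-5)
assert torch.allclose(value, packed[:, 2 * dim:], atol=1e-5)
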
11 changes: 11 additions & 0 deletions python/aitemplate/utils/misc.py
@@ -61,6 +61,17 @@ def short_str(s, length=8) -> str:
     return hash_str[0:length]


+def make_str_ait_friendly(s: str) -> str:
+    if s.isalnum():
+        return s
+    else:
+        ret = "".join(c if c.isalnum() else "_" for c in s)
+
+    if ret[0].isdigit():
+        ret = "_" + ret
+    return ret
+
+
 def callstack_stats(enable=False):
     if enable:

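
For reference, a short illustration of what make_str_ait_friendly produces. The function body is copied from the diff above so the snippet runs without AITemplate installed; the input strings are hypothetical PyTorch-style parameter names.

def make_str_ait_friendly(s: str) -> str:
    if s.isalnum():
        return s
    else:
        ret = "".join(c if c.isalnum() else "_" for c in s)

    if ret[0].isdigit():
        ret = "_" + ret
    return ret

# Dotted PyTorch parameter names become valid AIT tensor names.
print(make_str_ait_friendly("encoder.layers.0.self_attn.in_proj_weight"))
# -> encoder_layers_0_self_attn_in_proj_weight

# A leading digit gets an underscore prefix.
print(make_str_ait_friendly("0.out_proj_bias"))
# -> _0_out_proj_bias
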
