Weight name preservation for CrossAttention (#945)
Summary:

1. Make sure lowering_utils captures all constants, including those of submodules.
2. Add names for fx2ait's conversion of Attention Modules (see the naming sketch below).

Differential Revision: D50060817
muchulee8 authored and facebook-github-bot committed Oct 9, 2023
1 parent 7e0da3f commit 8531300
Showing 4 changed files with 54 additions and 44 deletions.
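
As background for item 2 of the summary: the renaming goes through make_str_ait_friendly, which this commit moves from fx2ait/fx2ait/utils.py into python/aitemplate/utils/misc.py, presumably so the module converter can tag each AIT constant with its original PyTorch parameter path. Below is a minimal Python sketch of the naming behavior, using hypothetical parameter paths; the helper body mirrors the version added to misc.py in this diff.

def make_str_ait_friendly(s: str) -> str:
    # Keep alphanumerics, replace everything else with "_", and prefix
    # names that would otherwise start with a digit.
    if s.isalnum():
        return s
    ret = "".join(c if c.isalnum() else "_" for c in s)
    if ret[0].isdigit():
        ret = "_" + ret
    return ret

# Hypothetical target/parameter paths, purely for illustration.
print(make_str_ait_friendly("encoder.attn.in_proj_weight"))  # encoder_attn_in_proj_weight
print(make_str_ait_friendly("0.out_proj.bias"))              # _0_out_proj_bias

In the converter this is applied to f"{target}.{k}", so the constant name recorded for each weight is derived from the FX target and the original PyTorch parameter name.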
74 changes: 41 additions & 33 deletions fx2ait/fx2ait/converters/ait_module_converters.py
@@ -19,8 +19,10 @@

import torch
from aitemplate.backend.target import Target
from aitemplate.compiler.base import _TorchConstantTensorData
from aitemplate.compiler.base import _TorchConstantTensorData, Tensor
from aitemplate.compiler.ops import chunk
from aitemplate.frontend import nn
from aitemplate.utils.misc import make_str_ait_friendly
from torch.fx.node import Argument

from .ait_converters import ConverterOutput
@@ -62,13 +64,45 @@ def multi_head_attention_module(
    )

    # Bind constant tensor for MHA module
    mapped_params = _map_ait_pt_params(attn, submod)
    ait_params = dict(attn.named_parameters())
    for name, data in mapped_params.items():
        ait_tensor = ait_params[name].tensor()
        ait_data = _TorchConstantTensorData(data.contiguous().cuda().half())
        ait_tensor._bind_data(ait_data)
    qkv_weight, qkv_bias = None, None
    proj_weight, proj_bias = None, None

Check failure on line 68 in fx2ait/fx2ait/converters/ait_module_converters.py [GitHub Actions / build (3.8)]: F841 local variable 'proj_weight' is assigned to but never used
Check failure on line 68 in fx2ait/fx2ait/converters/ait_module_converters.py [GitHub Actions / build (3.8)]: F841 local variable 'proj_bias' is assigned to but never used
    for k, v in submod.named_parameters():
        ait_data = _TorchConstantTensorData(v.data.contiguous().cuda().half())
        if "in_proj" in k:
            if "weight" in k:
                qkv_weight = Tensor(
                    shape=v.shape,
                    dtype="float16",
                    name=make_str_ait_friendly(f"{target}.{k}"),
                )
                qkv_weight._bind_data(ait_data)
            elif "bias" in k:
                qkv_bias = Tensor(
                    shape=v.shape,
                    dtype="float16",
                    name=make_str_ait_friendly(f"{target}.{k}"),
                )
                qkv_bias._bind_data(ait_data)
        elif "out_proj" in k:
            if "weight" in k:
                tensor = attn.proj.weight.tensor()
            elif "bias" in k:
                tensor = attn.proj.bias.tensor()
            tensor._attrs["name"] = make_str_ait_friendly(f"{target}.{k}")
            tensor._bind_data(ait_data)

    # Swap out qkv tensor used by nn.CrossAttention.
    q_w, k_w, v_w = chunk()(qkv_weight, 3)
    q_b, k_b, v_b = chunk()(qkv_bias, 3)

    attn.proj_q.weight._tensor = q_w
    attn.proj_k.weight._tensor = k_w
    attn.proj_v.weight._tensor = v_w
    attn.proj_q.bias._tensor = q_b
    attn.proj_k.bias._tensor = k_b
    attn.proj_v.bias._tensor = v_b

    ait_params = dict(attn.named_parameters())
    if "cu_length" in ait_params:
        ait_tensor = ait_params["cu_length"].tensor()
        cu_len = np.cumsum([0] + [seq_len.value()] * bsz.value()).astype("int32")
@@ -79,29 +113,3 @@ def multi_head_attention_module(
    res = attn(query, key, value)
    # make output of MHA a list to match the output type of pytorch MHA
    return [res]


def _map_ait_pt_params(ait_module, pt_module):
    ait_params = dict(ait_module.named_parameters())
    mapped_pt_params = OrderedDict()
    for pt_name, pt_param in pt_module.named_parameters():
        ait_friendly_name = (
            pt_name.replace("in_proj", "qkv")
            .replace("out_proj", "proj")
            .replace("_", ".")
        )
        if ait_friendly_name in ait_params:
            mapped_pt_params[ait_friendly_name] = pt_param.data
        elif "in_proj" in pt_name:
            # set constant for cross attention
            if len(pt_param.shape) == 2:
                w_q, w_k, w_v = pt_param.chunk(3)
                mapped_pt_params["proj_q.weight"] = w_q
                mapped_pt_params["proj_k.weight"] = w_k
                mapped_pt_params["proj_v.weight"] = w_v
            else:
                b_q, b_k, b_v = pt_param.chunk(3)
                mapped_pt_params["proj_q.bias"] = b_q
                mapped_pt_params["proj_k.bias"] = b_k
                mapped_pt_params["proj_v.bias"] = b_v
    return mapped_pt_params
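
For context on the chunk() calls in the new code: PyTorch's nn.MultiheadAttention stores the Q/K/V projections as one fused in_proj_weight of shape (3 * embed_dim, embed_dim) plus a matching in_proj_bias, so splitting along dim 0 recovers the per-projection parameters. The removed _map_ait_pt_params did this on the PyTorch side with pt_param.chunk(3); the new converter does the equivalent on the AIT side with the chunk() op and rebinds the results to the proj_q/proj_k/proj_v tensors. A small PyTorch-side sketch of the fused layout, assuming the default kdim/vdim so in_proj_weight exists:

import torch

embed_dim = 8
mha = torch.nn.MultiheadAttention(embed_dim, num_heads=2)

# Fused projection parameters: (3 * embed_dim, embed_dim) and (3 * embed_dim,).
print(mha.in_proj_weight.shape, mha.in_proj_bias.shape)

# Splitting along dim 0 yields the per-projection Q/K/V weights and biases,
# mirroring chunk()(qkv_weight, 3) in the converter above.
w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
assert w_q.shape == (embed_dim, embed_dim)
assert b_q.shape == (embed_dim,)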
3 changes: 2 additions & 1 deletion fx2ait/fx2ait/fx2ait.py
@@ -26,6 +26,7 @@

# @manual=//aitemplate/AITemplate/python/aitemplate:aitemplate
from aitemplate.testing import detect_target
from aitemplate.utils.misc import make_str_ait_friendly
from fx2ait.ait_module import ARG_SPLITTER_KEYWORD
from .converters.ait_converters import * # isort:skip # noqa: F401 F403
from .converters.aten2ait_converters import * # isort:skip # noqa: F401 F403
@@ -39,7 +40,7 @@
from .converters.converter_registry import AIT_CONVERTERS
from .tensor_spec import TensorSpec

from .utils import dtype_to_str, make_str_ait_friendly
from .utils import dtype_to_str

_LOGGER: logging.Logger = logging.getLogger(__name__)

10 changes: 0 additions & 10 deletions fx2ait/fx2ait/utils.py
@@ -19,13 +19,3 @@ def dtype_to_str(dtype):
    if dtype is None:
        return "float16"
    return torch_dtype_to_string(dtype)


def make_str_ait_friendly(s: str) -> str:
    if s.isalnum():
        ret = s
    else:
        ret = "".join(c if c.isalnum() else "_" for c in s)
    if ret[0].isdigit():
        ret = "_" + ret
    return ret
11 changes: 11 additions & 0 deletions python/aitemplate/utils/misc.py
@@ -61,6 +61,17 @@ def short_str(s, length=8) -> str:
    return hash_str[0:length]


def make_str_ait_friendly(s: str) -> str:
    if s.isalnum():
        return s
    else:
        ret = "".join(c if c.isalnum() else "_" for c in s)

    if ret[0].isdigit():
        ret = "_" + ret
    return ret


def callstack_stats(enable=False):
    if enable:

