Weight name preservation for CrossAttention (#945)
Summary:
Pull Request resolved: #945

1. Make sure lowering_utils captures all constants, including those of submodules.
2. Add names for fx2ait's conversion of Attention modules.

Reviewed By: khabinov

Differential Revision: D50060817

fbshipit-source-id: 696430bf40794cb1a76928069d68ca399077d26a
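
As context for point 1: torch.nn.Module.named_parameters() recurses into child modules by default, so iterating it over an nn.MultiheadAttention also yields the out_proj submodule's weight and bias, which is what lets the converter loop in this diff bind every constant by name. A minimal PyTorch sketch, illustrative only and not part of this commit (the embed_dim and num_heads values are arbitrary):

import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
for name, param in mha.named_parameters():
    # Prints in_proj_weight, in_proj_bias, out_proj.weight, out_proj.bias
    # with their shapes; submodule parameters are included in the iteration.
    print(name, tuple(param.shape))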
muchulee8 authored and facebook-github-bot committed Oct 10, 2023
1 parent 7e0da3f commit 49e08d2
Showing 1 changed file with 40 additions and 34 deletions.
fx2ait/fx2ait/converters/ait_module_converters.py: 40 additions & 34 deletions
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from collections import OrderedDict
 from typing import Any, Dict, Tuple
 
 import numpy as np
 
 import torch
 from aitemplate.backend.target import Target
-from aitemplate.compiler.base import _TorchConstantTensorData
+from aitemplate.compiler.base import _TorchConstantTensorData, Tensor
 from aitemplate.compiler.ops import chunk
 from aitemplate.frontend import nn
+from fx2ait.utils import make_str_ait_friendly
 from torch.fx.node import Argument
 
 from .ait_converters import ConverterOutput
@@ -62,13 +63,44 @@ def multi_head_attention_module(
     )
 
     # Bind constant tensor for MHA module
-    mapped_params = _map_ait_pt_params(attn, submod)
-    ait_params = dict(attn.named_parameters())
-    for name, data in mapped_params.items():
-        ait_tensor = ait_params[name].tensor()
-        ait_data = _TorchConstantTensorData(data.contiguous().cuda().half())
-        ait_tensor._bind_data(ait_data)
+    qkv_weight, qkv_bias = None, None
+    for k, v in submod.named_parameters():
+        ait_data = _TorchConstantTensorData(v.data.contiguous().cuda().half())
+        if "in_proj" in k:
+            if "weight" in k:
+                qkv_weight = Tensor(
+                    shape=v.shape,
+                    dtype="float16",
+                    name=make_str_ait_friendly(f"{target}.{k}"),
+                )
+                qkv_weight._bind_data(ait_data)
+            elif "bias" in k:
+                qkv_bias = Tensor(
+                    shape=v.shape,
+                    dtype="float16",
+                    name=make_str_ait_friendly(f"{target}.{k}"),
+                )
+                qkv_bias._bind_data(ait_data)
+        elif "out_proj" in k:
+            if "weight" in k:
+                tensor = attn.proj.weight.tensor()
+            elif "bias" in k:
+                tensor = attn.proj.bias.tensor()
+            tensor._attrs["name"] = make_str_ait_friendly(f"{target}.{k}")
+            tensor._bind_data(ait_data)
+
+    # Swap out qkv tensor used by nn.CrossAttention.
+    q_w, k_w, v_w = chunk()(qkv_weight, 3)
+    q_b, k_b, v_b = chunk()(qkv_bias, 3)
+
+    attn.proj_q.weight._tensor = q_w
+    attn.proj_k.weight._tensor = k_w
+    attn.proj_v.weight._tensor = v_w
+    attn.proj_q.bias._tensor = q_b
+    attn.proj_k.bias._tensor = k_b
+    attn.proj_v.bias._tensor = v_b
 
+    ait_params = dict(attn.named_parameters())
     if "cu_length" in ait_params:
         ait_tensor = ait_params["cu_length"].tensor()
         cu_len = np.cumsum([0] + [seq_len.value()] * bsz.value()).astype("int32")
@@ -79,29 +111,3 @@ def multi_head_attention_module(
     res = attn(query, key, value)
     # make output of MHA a list to match the output type of pytorch MHA
     return [res]
-
-
-def _map_ait_pt_params(ait_module, pt_module):
-    ait_params = dict(ait_module.named_parameters())
-    mapped_pt_params = OrderedDict()
-    for pt_name, pt_param in pt_module.named_parameters():
-        ait_friendly_name = (
-            pt_name.replace("in_proj", "qkv")
-            .replace("out_proj", "proj")
-            .replace("_", ".")
-        )
-        if ait_friendly_name in ait_params:
-            mapped_pt_params[ait_friendly_name] = pt_param.data
-        elif "in_proj" in pt_name:
-            # set constant for cross attention
-            if len(pt_param.shape) == 2:
-                w_q, w_k, w_v = pt_param.chunk(3)
-                mapped_pt_params["proj_q.weight"] = w_q
-                mapped_pt_params["proj_k.weight"] = w_k
-                mapped_pt_params["proj_v.weight"] = w_v
-            else:
-                b_q, b_k, b_v = pt_param.chunk(3)
-                mapped_pt_params["proj_q.bias"] = b_q
-                mapped_pt_params["proj_k.bias"] = b_k
-                mapped_pt_params["proj_v.bias"] = b_v
-    return mapped_pt_params
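
For reference, the q/k/v split that both the removed _map_ait_pt_params helper and the new converter loop rely on comes from the packed in_proj parameters of torch.nn.MultiheadAttention, which stack the query, key, and value projections along dim 0; chunk(3) recovers the three pieces that get bound to proj_q/proj_k/proj_v. A minimal PyTorch sketch, illustrative only and not part of this commit (embed_dim and num_heads are arbitrary):

import torch

embed_dim = 8
mha = torch.nn.MultiheadAttention(embed_dim, num_heads=2)
# in_proj_weight has shape (3 * embed_dim, embed_dim) and in_proj_bias has
# shape (3 * embed_dim,); chunk(3) splits each along dim 0 into the query,
# key, and value projection pieces.
w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
assert w_q.shape == (embed_dim, embed_dim)
assert b_q.shape == (embed_dim,)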
