diff --git a/fx2ait/fx2ait/converters/ait_module_converters.py b/fx2ait/fx2ait/converters/ait_module_converters.py
index 892f05bfc..fad6920b4 100644
--- a/fx2ait/fx2ait/converters/ait_module_converters.py
+++ b/fx2ait/fx2ait/converters/ait_module_converters.py
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from collections import OrderedDict
 from typing import Any, Dict, Tuple
 
 import numpy as np
 import torch
 from aitemplate.backend.target import Target
-from aitemplate.compiler.base import _TorchConstantTensorData
+from aitemplate.compiler.base import _TorchConstantTensorData, Tensor
+from aitemplate.compiler.ops import chunk
 from aitemplate.frontend import nn
+from fx2ait.utils import make_str_ait_friendly
 from torch.fx.node import Argument
 
 from .ait_converters import ConverterOutput
 from .converter_registry import ait_converter
@@ -62,13 +63,44 @@ def multi_head_attention_module(
     )
 
     # Bind constant tensor for MHA module
-    mapped_params = _map_ait_pt_params(attn, submod)
-    ait_params = dict(attn.named_parameters())
-    for name, data in mapped_params.items():
-        ait_tensor = ait_params[name].tensor()
-        ait_data = _TorchConstantTensorData(data.contiguous().cuda().half())
-        ait_tensor._bind_data(ait_data)
+    qkv_weight, qkv_bias = None, None
+    for k, v in submod.named_parameters():
+        ait_data = _TorchConstantTensorData(v.data.contiguous().cuda().half())
+        if "in_proj" in k:
+            if "weight" in k:
+                qkv_weight = Tensor(
+                    shape=v.shape,
+                    dtype="float16",
+                    name=make_str_ait_friendly(f"{target}.{k}"),
+                )
+                qkv_weight._bind_data(ait_data)
+            elif "bias" in k:
+                qkv_bias = Tensor(
+                    shape=v.shape,
+                    dtype="float16",
+                    name=make_str_ait_friendly(f"{target}.{k}"),
+                )
+                qkv_bias._bind_data(ait_data)
+        elif "out_proj" in k:
+            if "weight" in k:
+                tensor = attn.proj.weight.tensor()
+            elif "bias" in k:
+                tensor = attn.proj.bias.tensor()
+            tensor._attrs["name"] = make_str_ait_friendly(f"{target}.{k}")
+            tensor._bind_data(ait_data)
+
+    # Swap out qkv tensor used by nn.CrossAttention.
+    q_w, k_w, v_w = chunk()(qkv_weight, 3)
+    q_b, k_b, v_b = chunk()(qkv_bias, 3)
+
+    attn.proj_q.weight._tensor = q_w
+    attn.proj_k.weight._tensor = k_w
+    attn.proj_v.weight._tensor = v_w
+    attn.proj_q.bias._tensor = q_b
+    attn.proj_k.bias._tensor = k_b
+    attn.proj_v.bias._tensor = v_b
 
+    ait_params = dict(attn.named_parameters())
     if "cu_length" in ait_params:
         ait_tensor = ait_params["cu_length"].tensor()
         cu_len = np.cumsum([0] + [seq_len.value()] * bsz.value()).astype("int32")
@@ -79,29 +111,3 @@
     res = attn(query, key, value)
     # make output of MHA a list to match the output type of pytorch MHA
     return [res]
-
-
-def _map_ait_pt_params(ait_module, pt_module):
-    ait_params = dict(ait_module.named_parameters())
-    mapped_pt_params = OrderedDict()
-    for pt_name, pt_param in pt_module.named_parameters():
-        ait_friendly_name = (
-            pt_name.replace("in_proj", "qkv")
-            .replace("out_proj", "proj")
-            .replace("_", ".")
-        )
-        if ait_friendly_name in ait_params:
-            mapped_pt_params[ait_friendly_name] = pt_param.data
-        elif "in_proj" in pt_name:
-            # set constant for cross attention
-            if len(pt_param.shape) == 2:
-                w_q, w_k, w_v = pt_param.chunk(3)
-                mapped_pt_params["proj_q.weight"] = w_q
-                mapped_pt_params["proj_k.weight"] = w_k
-                mapped_pt_params["proj_v.weight"] = w_v
-            else:
-                b_q, b_k, b_v = pt_param.chunk(3)
-                mapped_pt_params["proj_q.bias"] = b_q
-                mapped_pt_params["proj_k.bias"] = b_k
-                mapped_pt_params["proj_v.bias"] = b_v
-    return mapped_pt_params
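
For context, a minimal PyTorch-only sketch of the parameter mapping the new binding code relies on. It is not part of the patch, and the toy embed_dim/num_heads values and variable names are illustrative: torch.nn.MultiheadAttention stores the fused q/k/v projection as in_proj_weight / in_proj_bias, and splitting each into three equal chunks along dim 0 recovers the separate query/key/value projections that nn.CrossAttention's proj_q / proj_k / proj_v consume, which is what chunk()(qkv_weight, 3) does on the AIT side.

    # Illustrative sketch only; not part of the diff above.
    import torch

    embed_dim, num_heads = 8, 2
    mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    # in_proj_weight has shape (3 * embed_dim, embed_dim); in_proj_bias has
    # shape (3 * embed_dim,). chunk(3) splits along dim 0 by default.
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)

    # The query branch of MHA is just a linear layer with the first chunk.
    x = torch.randn(1, 4, embed_dim)
    q = torch.nn.functional.linear(x, w_q, b_q)
    assert q.shape == (1, 4, embed_dim)

Compared with the removed _map_ait_pt_params helper, which pre-chunked the torch tensors before binding, the new code binds the fused in_proj tensors as single constants and then swaps the chunked AIT views into proj_q/proj_k/proj_v, presumably so that each PyTorch parameter maps to exactly one named AIT constant (via make_str_ait_friendly(f"{target}.{k}")).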