
pnnx: multi-head attention recognition issue #5662

Open
futz12 opened this issue Aug 31, 2024 · 2 comments

Comments

@futz12

futz12 commented Aug 31, 2024

From reading the pnnx code, I've observed some shortcomings in how it recognizes the multi-head attention layer.

pnnx currently only supports torch's built-in multi-head attention, which seems inconvenient for real-world deployment (for example, BERT's multi-head attention in transformer.py).

The main problems fall into two areas. First, the scaled dot-product attention pattern requires an explicit scale parameter, but in the torch documentation scale is optional (the default is 1 / math.sqrt(query.size(-1))).
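A minimal sketch of what that default means in plain PyTorch (assuming a torch version recent enough to accept the scale keyword): omitting scale and writing it out explicitly give the same result, so a matcher that insists on an explicit scale misses the common case.

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# scale omitted: torch falls back to 1 / sqrt(query.size(-1))
out_default = F.scaled_dot_product_attention(q, k, v)

# scale written out explicitly, same value as the documented default
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(q.size(-1)))

assert torch.allclose(out_default, out_explicit)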

Second, different libraries build multi-head attention out of interchangeable operators such as view/reshape and permute/transpose, and these synonymous operator sequences are not recognized well; a small illustration follows.
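A quick sketch of the synonym problem (shapes and variable names here are only illustrative): splitting heads with view + permute and with reshape + transpose produces identical tensors, but the two traces contain different operator names, so a literal pattern match sees them as different graphs.

import torch

batch, seq, num_heads, head_dim = 2, 16, 8, 64
x = torch.randn(batch, seq, num_heads * head_dim)

# variant A: view + permute (the style transformers' BERT code tends to emit)
a = x.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)

# variant B: reshape + transpose (the style another library might emit)
b = x.reshape(batch, seq, num_heads, head_dim).transpose(1, 2)

# same values, but two different operator sequences in the traced graph
assert torch.equal(a, b)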

It would be good if ncnn could provide a scaled dot-product attention operator.

@futz12
Author

futz12 commented Aug 31, 2024

For example, the IR for BertSdpaSelfAttention should be written along these lines:

15 14
pnnx.Input                     input       0 1 input
nn.Linear                      op_0        1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
Tensor.view                    op_3        1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute                 op_6        1 1 10 16 dims=(0,2,1,3)
nn.Linear                      op_1        1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
Tensor.view                    op_4        1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute                 op_7        1 1 12 17 dims=(0,2,1,3)
nn.Linear                      op_2        1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.view                    op_5        1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute                 op_8        1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9        3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None
Tensor.transpose               op_10       1 1 19 20 dim0=%dim0 dim1=%dim1
Tensor.reshape                 op_11       1 1 20 21 shape=(%batch,%size,%embed_dim)
nn.Linear                      out_proj    1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output                    output      1 0 out
)PNNXIR
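For context, a rough PyTorch sketch of the kind of module this pattern would match. This is not the real BertSdpaSelfAttention from transformers (which carries more configuration and keeps the output projection in a separate module); the class and parameter names here are illustrative, chosen only to mirror the operator sequence in the IR above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SdpaSelfAttentionSketch(nn.Module):
    # hypothetical stand-in for a BERT-style SDPA self-attention block
    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch, size, embed_dim = x.shape
        # Linear -> view -> permute, matching the three q/k/v branches in the IR
        q = self.query(x).view(batch, size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(x).view(batch, size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.value(x).view(batch, size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # scale left at its default, as in the IR (no explicit scale=)
        attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        # transpose -> reshape -> output projection, matching the tail of the IR
        attn = attn.transpose(1, 2).reshape(batch, size, embed_dim)
        return self.out_proj(attn)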

@wzyforgit
Contributor

Well, the p in pnnx stands for pytorch (
