Inconsistent forward behavior in Attention #146

Closed · rentainhe opened this issue Feb 28, 2022 · 16 comments

@rentainhe (Contributor) commented Feb 28, 2022

Contents

This issue documents the inconsistency between LiBai's MultiHeadAttention and the implementation in timm, and discusses possible solutions.

Problem description

The main difference between LiBai's MultiHeadAttention and timm's Attention lies in the forward pass; simplified implementations are shown below:

  • LiBai's Attention implementation
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # LiBai's query/key/value split
        # =======================
        qkv = self.qkv(x)
        qkv = qkv.view(B, -1, self.num_heads, 3 * C // self.num_heads)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
        # =======================
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
  • timm's Attention implementation
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # timm's query/key/value split
        # =======================
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        # =======================
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

The problem is that the two implementations split query, key, and value differently. As a result, even when the same weights are loaded, the two modules do not produce the same forward results. Since vision transformer implementations generally follow timm's Attention, reproducing a ViT model with LiBai's built-in layers runs into the awkward situation of loading identical weights but not getting identical inference results.
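
A minimal sketch of the symptom, assuming the two classes above have been imported under the hypothetical names LiBaiAttention and TimmAttention (renamed here only to avoid the clash between the two Attention definitions):

import torch

dim, num_heads = 24, 4
libai_attn = LiBaiAttention(dim, num_heads=num_heads, qkv_bias=True).eval()
timm_attn = TimmAttention(dim, num_heads=num_heads, qkv_bias=True).eval()

# the two modules have identical parameter names and shapes, so the weights load cleanly
timm_attn.load_state_dict(libai_attn.state_dict())

x = torch.randn(2, 5, dim)
with torch.no_grad():
    # the outputs still differ, because q/k/v are sliced out of the qkv projection
    # in a different order in the two forward passes
    print(torch.allclose(libai_attn(x), timm_attn(x)))  # expected: False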

Possible solutions

  1. Add a timmAttention implementation and modify TransformerLayer so that it takes an Attention argument specifying which attention module to use. The benefit: if a user wants to swap the attention module, say Linear Attention, Performer, or any other attention block instead of vanilla attention, they can simply specify it as an argument, and thanks to our LazyCall mechanism the modules can be composed quite flexibly (see the sketch after this list). The inconvenient part is that LiBai's Attention takes some arguments that cannot necessarily be unified across every attention module, so the inputs would need to be standardized.
  2. Leave LiBai's code untouched and maintain a separate Attention implementation inside the sub-project, as projects like MAE do. That works, but it feels awkward; I personally lean toward the first option.
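
A rough illustration of option 1; the TransformerLayer below and its attn_layer argument are a hypothetical sketch, not LiBai's actual TransformerLayer API:

import torch.nn as nn


class TransformerLayer(nn.Module):
    # Hypothetical sketch: the attention module is injected as a callable, so any
    # variant (LiBai/Megatron-style, timm-style, Linear Attention, Performer, ...)
    # can be plugged in, as long as it accepts (dim, num_heads=...) and maps
    # (B, N, C) -> (B, N, C).
    def __init__(self, dim, num_heads, attn_layer, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = attn_layer(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


# e.g. layer = TransformerLayer(dim=768, num_heads=12, attn_layer=Attention)

With LazyCall, the choice of attn_layer could then live in the config rather than in the layer code.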

Weight conversion without modifying LiBai's code

import torch
import torch.nn.functional as F

b = 2
n = 5
num_heads = 4
cn = 6
c = cn*num_heads
x = torch.randn(b*n*c).view(b, n, c)
weight = torch.randn(c*c*3)
bias = torch.rand(c*3)


weight1 = weight.view(c*3, c)
# timm order -> LiBai order:
# (3, head, head_size, hidden_size) -> (head, 3, head_size, hidden_size) -> (head * 3 * head_size, hidden_size)
weight1 = weight1.view(3, num_heads, cn, c).permute(1, 0, 2, 3).contiguous().view(c*3, c)
weight2 = weight.view(c*3, c)   # the timm side keeps its original layout

# the bias is reordered the same way: (3, head, head_size) -> (head, 3, head_size)
bias1 = bias.view(3, num_heads, cn).permute(1, 0, 2).contiguous().view(c*3)
bias2 = bias

qkv1 = F.linear(x, weight1, bias=bias1)
qkv2 = F.linear(x, weight2, bias=bias2)


# libai version
qkv1 = qkv1.view(b, -1, num_heads, 3*c//num_heads)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)


# timm version
qkv2 = qkv2.reshape(b, n, 3, num_heads, c//num_heads)
qkv2 = qkv2.permute(2, 0, 3, 1, 4)
q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]

print((q1 == q2).all())     # tensor(True)
print((k1 == k2).all())     # tensor(True)
print((v1 == v2).all())     # tensor(True)
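
The same reindexing can also be checked end to end at the module level — a sketch under the same assumption that the two classes above are available as LiBaiAttention and TimmAttention:

import torch

dim, num_heads = 24, 4
head_dim = dim // num_heads

timm_attn = TimmAttention(dim, num_heads=num_heads, qkv_bias=True).eval()
libai_attn = LiBaiAttention(dim, num_heads=num_heads, qkv_bias=True).eval()

state = timm_attn.state_dict()
# (3, head, head_dim, hidden) -> (head, 3, head_dim, hidden), flattened back to (3*hidden, hidden)
state["qkv.weight"] = (
    state["qkv.weight"].view(3, num_heads, head_dim, dim).permute(1, 0, 2, 3).reshape(3 * dim, dim)
)
# the bias is reordered the same way: (3, head, head_dim) -> (head, 3, head_dim)
state["qkv.bias"] = state["qkv.bias"].view(3, num_heads, head_dim).permute(1, 0, 2).reshape(3 * dim)
libai_attn.load_state_dict(state)

x = torch.randn(2, 5, dim)
with torch.no_grad():
    print(torch.allclose(libai_attn(x), timm_attn(x), atol=1e-6))  # expected: True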
@yuanms2 commented Feb 28, 2022

These two ways of writing it should be mathematically equivalent, right? I think we discussed the difference back when working on ViT.

@kaijieshi7

> These two ways of writing it should be mathematically equivalent, right? I think we discussed the difference back when working on ViT.

The shapes of the results are equivalent; the values are not.

@rentainhe (Contributor, Author)

> These two ways of writing it should be mathematically equivalent, right? I think we discussed the difference back when working on ViT.

For the Self-Attention module as a whole, the end effect is equivalent; only the query, key, and value produced by the split differ, so training from scratch would reach a similar final accuracy either way. But when loading someone else's weights, the different query/key/value split makes it impossible to reproduce their results.

@kaijieshi7

In the code below, given the same qkv tensor, the two split methods produce (q1, k1, v1) and (q2, k2, v2). Since q1.sum() differs from q2.sum(), one of the two splits yields q, k, v that differ from the qkv you would intuitively expect.

import torch

b = 2
n = 5
c = 6*4
num_heads = 4

qkv1 = torch.arange(b*n*c*3)
qkv1 = qkv1.view(b, -1, num_heads, 3*c//num_heads)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)
q1_s = q1.sum()

qkv2 = torch.arange(b*n*c*3)
qkv2 = qkv2.reshape(b, n, 3, num_heads, c//num_heads)
qkv2 = qkv2.permute(2, 0, 3, 1, 4)
q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
q2_s = q2.sum()
print(q1_s)
print(q2_s)

@rentainhe (Contributor, Author)

> In the code below, given the same qkv tensor, the two split methods produce (q1, k1, v1) and (q2, k2, v2). Since q1.sum() differs from q2.sum(), one of the two splits yields q, k, v that differ from the qkv you would intuitively expect.

I wouldn't say one of them is wrong. As self-attention they are the same: I just need to get q, k, v and compute attention with them, so the effect is identical; only the split differs. There is no single "intended" qkv. Abstractly, all I need is to produce query, key, and value from the input; how they are obtained through reshape or view doesn't really matter.

@kaijieshi7

The code below gives one way to convert the weights; after conversion the two versions produce the same results.

import torch
import torch.nn.functional as F

b = 2
n = 5
num_heads = 4
cn = 6
c = cn*num_heads
x = torch.arange(b*n*c).view(b, n, c)
weight = torch.arange(c*c*3)
weight1 = weight.view(c*3, c)
weight2 = weight.view(num_heads, 3, cn, c).permute(1, 0, 2, 3).contiguous().view(c*3, c)
qkv1 = F.linear(x, weight1, bias=None)
qkv2 = F.linear(x, weight2, bias=None)

qkv1 = qkv1.view(b, -1, num_heads, 3*c//num_heads)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)
q1_s = q1.sum()
k1_s = k1.sum()
v1_s = v1.sum()

qkv2 = qkv2.reshape(b, n, 3, num_heads, c//num_heads)
qkv2 = qkv2.permute(2, 0, 3, 1, 4)
q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
q2_s = q2.sum()
k2_s = k2.sum()
v2_s = v2.sum()
print(q1_s==q2_s)
print(k1_s==k2_s)
print(v1_s==v2_s)

@rentainhe (Contributor, Author) commented Mar 1, 2022

> The code below gives one way to convert the weights; after conversion the two versions produce the same results.

That is a workable conversion approach, but it is still a bit cumbersome and this detail is easy to overlook. I still think adding a timmAttention is the more reasonable fix.

rentainhe added the "bug" (Something isn't working) label on Mar 1, 2022
@yuanms2 commented Mar 4, 2022

By "mathematically equivalent" I meant the split being the same as well; here the computed results are not the same.

From the discussion above, the two splits are different, so they are mathematically different; they only behave similarly in function.

Then we should drop the first style entirely (do any other libraries use it?) and align fully with timm.

@rentainhe (Contributor, Author) commented Mar 4, 2022

> By "mathematically equivalent" I meant the split being the same as well; here the computed results are not the same.
>
> From the discussion above, the two splits are different, so they are mathematically different; they only behave similarly in function.
>
> Then we should drop the first style entirely (do any other libraries use it?) and align fully with timm.

There are actually quite a few ways to write self-attention. The first style here is aligned with Megatron, and the second is the one timm uses (both appear in some large model zoos), so we compromised and want to keep both styles, mainly to make loading torch weights easier.

@yuanms2 commented Mar 4, 2022

OK. If both styles are in use, then we can keep both, distinguish them by name, and add a brief comment.

@rentainhe (Contributor, Author)

> OK. If both styles are in use, then we can keep both, distinguish them by name, and add a brief comment.

OK, we should add timmAttention in the next version and make the code more elegant.

@rentainhe (Contributor, Author)

Alignment code with bias added

# import torch

# mae_dict = torch.load("/home/rentianhe/code/OneFlow-Models/libai/mae_pretrain_vit_base.pth")["model"]
# for key, value in mae_dict.items():
#     print(key)


import torch
import torch.nn.functional as F

b = 2
n = 5
num_heads = 4
cn = 6
c = cn*num_heads
x = torch.randn(b*n*c).view(b, n, c)

weight = torch.randn(c*c*3)
bias = torch.rand(c*3)


weight1 = weight.view(c*3, c)
weight2 = weight.view(num_heads, 3, cn, c).permute(1, 0, 2, 3).contiguous().view(c*3, c)

bias1 = bias
bias2 = bias.view(num_heads, 3, cn).permute(1, 0, 2).contiguous().view(c*3)

qkv1 = F.linear(x, weight1, bias=bias1)
qkv2 = F.linear(x, weight2, bias=bias2)

qkv1 = qkv1.view(b, -1, num_heads, 3*c//num_heads)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)

qkv2 = qkv2.reshape(b, n, 3, num_heads, c//num_heads)
qkv2 = qkv2.permute(2, 0, 3, 1, 4)
q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
print((q1==q2).all())     # tensor(True)
print((k1==k2).all())     # tensor(True)
print((v1==v2).all())     # tensor(True)

@xiezipeng-ML (Contributor)

> Alignment code with bias added

Here qkv2 corresponds to timm Attention's computation, and what we need is to load timm's weights into libai, i.e. weight represents the timm pretrained model. So shouldn't it be weight1 that gets adjusted, so that the qkv1 result aligns with qkv2? In other words, weight2 shouldn't be the one adjusted here?
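
In other words (a minimal sketch; timm_qkv_weight and libai_qkv_weight are hypothetical names): the timm weight stays as it is, and the copy loaded into LiBai is the one that gets reindexed.

import torch

hidden, num_heads = 24, 4
head_dim = hidden // num_heads

# pretend this is the pretrained timm qkv weight, rows ordered as (3, head, head_dim)
timm_qkv_weight = torch.randn(3 * hidden, hidden)

# reorder the rows to LiBai's (head, 3, head_dim) layout before loading;
# the timm side is left untouched
libai_qkv_weight = (
    timm_qkv_weight.view(3, num_heads, head_dim, hidden)
    .permute(1, 0, 2, 3)
    .contiguous()
    .view(3 * hidden, hidden)
)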

@rentainhe (Contributor, Author)

> Here qkv2 corresponds to timm Attention's computation, and what we need is to load timm's weights into libai, i.e. weight represents the timm pretrained model. So shouldn't it be weight1 that gets adjusted, so that the qkv1 result aligns with qkv2? In other words, weight2 shouldn't be the one adjusted here?

Ah, then that part probably is indeed wrong and the weights were loaded incorrectly; I'll flip it around and try.

@rentainhe (Contributor, Author) commented Mar 28, 2022

An implementation attempt

This version is already close, but something still seems slightly off somewhere. Please help double-check it @xiezipeng-ML @ZiqiuChi

# import torch

# mae_dict = torch.load("/home/rentianhe/code/OneFlow-Models/libai/mae_pretrain_vit_base.pth")["model"]
# for key, value in mae_dict.items():
#     print(key)


import torch
import torch.nn.functional as F

b = 2
n = 5
num_heads = 4
cn = 6
c = cn*num_heads
x = torch.randn(b*n*c).view(b, n, c)
weight = torch.randn(c*c*3)
bias = torch.rand(c*3)


weight1 = weight.view(c*3, c)
# timm order -> LiBai order:
# (3, head, head_size, hidden_size) -> (head, 3, head_size, hidden_size) -> (head * 3 * head_size, hidden_size)
weight1 = weight1.view(3, num_heads, cn, c).permute(1, 0, 2, 3).contiguous().view(c*3, c)
weight2 = weight.view(c*3, c)   # the timm side keeps its original layout

# the bias is reordered the same way: (3, head, head_size) -> (head, 3, head_size)
bias1 = bias.view(3, num_heads, cn).permute(1, 0, 2).contiguous().view(c*3)
bias2 = bias

qkv1 = F.linear(x, weight1, bias=bias1)
qkv2 = F.linear(x, weight2, bias=bias2)


# libai version
qkv1 = qkv1.view(b, -1, num_heads, 3*c//num_heads)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)


# timm version
qkv2 = qkv2.reshape(b, n, 3, num_heads, c//num_heads)
qkv2 = qkv2.permute(2, 0, 3, 1, 4)
q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]

print((q1 == q2).all())     # tensor(True)
print((k1 == k2).all())     # tensor(True)
print((v1 == v2).all())     # tensor(True)
  • qkv split in LiBai
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(batch_size, -1, num_heads, 3 * head_size)
qkv = qkv.permute(0, 2, 1, 3)
q, k, v = flow.chunk(qkv, chunks=3, dim=-1)
  • qkv split in timm
qkv = self.query_key_value(hidden_states)
qkv = qkv.reshape(batch_size, num_seq, 3, num_heads, head_size)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

@rentainhe (Contributor, Author)

> An implementation attempt
>
> This version is already close, but something still seems slightly off somewhere. Please help double-check it @xiezipeng-ML @ZiqiuChi

This version is the correct implementation; after testing, it produces the correct inference results.
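
Finally, a sketch of how this conversion might be applied to a whole timm/MAE checkpoint before loading it into LiBai. The helper name, the ".attn.qkv." key pattern, and num_heads=12 (ViT-Base) are assumptions for illustration; the checkpoint filename and the ["model"] key come from the commented-out snippet above.

import torch


def timm_qkv_to_libai(t, num_heads):
    # t: a timm qkv weight of shape (3*hidden, hidden) or qkv bias of shape (3*hidden,)
    head_dim = t.shape[0] // 3 // num_heads
    rest = t.shape[1:]                          # (hidden,) for the weight, () for the bias
    t = t.view(3, num_heads, head_dim, *rest)   # (3, head, head_dim, ...)
    t = t.transpose(0, 1).contiguous()          # -> (head, 3, head_dim, ...)
    return t.view(3 * num_heads * head_dim, *rest)


# Hypothetical usage: reorder every qkv projection in a timm-style ViT/MAE checkpoint.
timm_state = torch.load("mae_pretrain_vit_base.pth", map_location="cpu")["model"]
converted = {
    key: timm_qkv_to_libai(value, num_heads=12) if ".attn.qkv." in key else value
    for key, value in timm_state.items()
}
# `converted` still uses timm's key names; renaming them to LiBai's parameter
# names is a separate step that is omitted here.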

rentainhe added the "guide" label and removed the "bug" (Something isn't working) label on Mar 29, 2022
xiezipeng-ML reopened this on Jul 29, 2022