Load HuggingFace Bert #205
A quick reply on the question about the LayerNorm position: we follow the Megatron code implementation, and the Megatron-LM paper has a passage about where the residual connection is placed,
so the placement in LiBai's TransformerLayer differs from the original BERT. |
The LayerNorm position is indeed fine; the computation is correct. |
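For readers who have not seen the distinction, here is a schematic of the two orderings being discussed (my paraphrase of the pre-LN vs. post-LN difference, not actual LiBai, Megatron, or huggingface code):
# post-LN (original BERT / huggingface): LayerNorm sits after the residual addition
#   h   = LayerNorm(x + Attention(x))
#   out = LayerNorm(h + MLP(h))
#
# pre-LN (Megatron-style, as referenced above for LiBai's TransformerLayer): LayerNorm sits on the
# sub-block input, and the residual bypasses it
#   h   = x + Attention(LayerNorm(x))
#   out = h + MLP(LayerNorm(h))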
For the qkv computation we also follow the Megatron code below, except that LiBai does not move the sequence dimension to the front; the corresponding flow is
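(Shape flow reconstructed from the LiBai code quoted later in this thread, with the batch dimension kept first rather than sequence first:)
# query_key_value: [bsz, seq_len, 3*hidden_size]
#   .view(bsz, seq_len, num_heads, 3*head_size)  -> [bsz, seq_len, num_heads, 3*head_size]
#   .permute(0, 2, 1, 3)                         -> [bsz, num_heads, seq_len, 3*head_size]
#   flow.chunk(chunks=3, dim=-1)                 -> q, k, v: [bsz, num_heads, seq_len, head_size]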
|
This issue is really good. It could be turned into a standalone FAQ entry or an Advanced Tutorials page. |
Documented the issue that the two qkv computation methods lead to different sbp signatures, @CPFLAME |
We worked through the derivation and found that aligning with the huggingface implementation breaks the sbp we derived earlier: huggingface does the chunk first, and the chunk dimension is exactly the sbp.split dimension, so the split implies an extra communication step in the middle. We think this might bring further problems, so could you try the approach Kaijie suggested earlier? |
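To make the communication point concrete, here is a small arithmetic illustration (my addition, assuming 12 heads, head_size 64, and a 2-way split of the fused qkv output dimension; not LiBai code):
num_heads, head_size = 12, 64
hidden = num_heads * head_size          # 768
out_dim = 3 * hidden                    # 2304 output channels of the fused qkv linear
shard = out_dim // 2                    # 2-way tensor parallel: each rank owns 1152 channels
# LiBai / Megatron layout: channels grouped per head as [h0_q | h0_k | h0_v | h1_q | ...]
print(shard % (3 * head_size))          # 0 -> a shard boundary never cuts through a head,
                                        #      so q, k, v can be chunked off locally on each rank
# huggingface layout: channels grouped as [all Q (768) | all K (768) | all V (768)]
print(shard - hidden)                   # 384 -> rank 0 holds all of Q plus half of K, so chunking
                                        #        the full 2304 channels needs the global tensor
                                        #        (to_global), i.e. extra communication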
OK Xingyu, I will look into how to load the weights correctly so that the outputs align. |
Keep the model unchanged and change how the weights are loaded to get the same result. We have already shown that LiBai's qkv computation is correct (switching to the huggingface qkv computation breaks the sbp under model parallelism; the only current workaround is a direct to_global, and it is unclear whether that causes other problems; in other words, the sbp scheme of the whole LiBai model is designed as a set, and swapping in another computation causes trouble), so here we consider a different way of loading the weights.

The two qkv computation methods:

# qkv computation in LiBai:
# query_key_value:[batch_size, seq_len, 3*hidden_size]
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size) #(a)
query_key_value = query_key_value.permute(0, 2, 1, 3) #(b)
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1) #(c)
# qkv computation in huggingface:
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
query = query.view(query.size(0), query.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
key = key.view(key.size(0), key.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
value = value.view(value.size(0), value.size(1), self.num_heads, -1).permute(0, 2, 1, 3)

First, an explanation of why in LiBai's layout the per-head q, k, v channels overlap (interleave):
# q, k, v are interleaved per head: viewing the 3*hidden output dimension as
# [num_heads, 3*head_size], each row holds one head's 192 consecutive channels (q | k | v).
flow.arange(1, 2305).view(12, 3*64)
tensor([[ 1, 2, 3, ..., 190, 191, 192],
[ 193, 194, 195, ..., 382, 383, 384],
[ 385, 386, 387, ..., 574, 575, 576],
...,
[1729, 1730, 1731, ..., 1918, 1919, 1920],
[1921, 1922, 1923, ..., 2110, 2111, 2112],
[2113, 2114, 2115, ..., 2302, 2303, 2304]])

Solution approach:
import torch
import torch.nn.functional as F
bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size
x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
# bias = torch.rand(2304)
# my method for weight------------------------------
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0) # [4, 3, head_size, hidden_size]
weight_q = weight_q.view(-1, head_size, hidden_size) # [12, head_size, hidden_size]
weight_k = weight_k.view(-1, head_size, hidden_size)
weight_v = weight_v.view(-1, head_size, hidden_size)
weight_q = weight_q.unsqueeze(1)
weight_k = weight_k.unsqueeze(1)
weight_v = weight_v.unsqueeze(1)
weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1) # [num_heads, 3, head_size, hidden_size]
weight1 = weight1.view(-1, hidden_size)
# my method for weight end-----------------------------------------------------
weight2 = weight
qkv1 = F.linear(x, weight1, bias=None)
qkv2 = F.linear(x, weight2, bias=None)
# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)
# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)
print((q1==q2).all()) # tensor(True)
print((k1==k2).all()) # tensor(True)
print((v1==v2).all()) # tensor(True)
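As a side note, the same rearrangement can be written in one shot; a sketch under the same shapes, checked against weight1 above (my addition, not from the original):
# View the huggingface rows as [3 (q/k/v), num_heads, head_size, hidden], swap the first two
# axes so each head's q, k, v become contiguous, then flatten back to [3*hidden, hidden].
weight1_alt = weight.view(3, num_heads, head_size, hidden_size).permute(1, 0, 2, 3).reshape(-1, hidden_size)
print((weight1_alt == weight1).all())  # tensor(True)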
|
Solution for the bias:
import torch
import torch.nn.functional as F
import pdb
bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size
x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
bias = torch.rand(2304)
# my method for weight------------------------------
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0) # [4, 3, head_size, hidden_size]
weight_q = weight_q.view(-1, head_size, hidden_size) # [12, head_size, hidden_size]
weight_k = weight_k.view(-1, head_size, hidden_size)
weight_v = weight_v.view(-1, head_size, hidden_size)
weight_q = weight_q.unsqueeze(1)
weight_k = weight_k.unsqueeze(1)
weight_v = weight_v.unsqueeze(1)
weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1) # [num_heads, 3, head_size, hidden_size]
weight1 = weight1.view(-1, hidden_size)
# my method for weight end-----------------------------------------------------
weight2 = weight
# --------------convert bias-------------------------------
bias_ = bias.view(num_heads, 3, head_size)
bias_q, bias_k, bias_v = bias_.chunk(3, dim=0)
bias_q = bias_q.view(-1, head_size).unsqueeze(1)
bias_k = bias_k.view(-1, head_size).unsqueeze(1)
bias_v = bias_v.view(-1, head_size).unsqueeze(1)
bias1 = torch.cat([bias_q, bias_k, bias_v], dim=1).view(-1)
# -----------------------------------------------------------
qkv1 = F.linear(x, weight1, bias=bias1) # weight1: [2304, 768]
qkv2 = F.linear(x, weight2, bias=bias)
# pdb.set_trace()
# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)
# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)
print((q1==q2).all()) # tensor(True)
print((k1==k2).all()) # tensor(True)
print((v1==v2).all()) # tensor(True)

Consolidated code:

import torch
import torch.nn.functional as F
bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size
x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
bias = torch.rand(2304)
# convert weight and bias
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0)
weight_q = weight_q.view(-1, head_size, hidden_size).unsqueeze(1)
weight_k = weight_k.view(-1, head_size, hidden_size).unsqueeze(1)
weight_v = weight_v.view(-1, head_size, hidden_size).unsqueeze(1)
weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1).view(-1, hidden_size)
bias_ = bias.view(num_heads, 3, head_size)
bias_q, bias_k, bias_v = bias_.chunk(3, dim=0)
bias_q = bias_q.view(-1, head_size).unsqueeze(1)
bias_k = bias_k.view(-1, head_size).unsqueeze(1)
bias_v = bias_v.view(-1, head_size).unsqueeze(1)
bias1 = torch.cat([bias_q, bias_k, bias_v], dim=1).view(-1)
weight2 = weight
bias2 = bias
qkv1 = F.linear(x, weight1, bias=bias1)
qkv2 = F.linear(x, weight2, bias=bias2)
# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)
# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)
print((q1==q2).all()) # tensor(True)
print((k1==k2).all()) # tensor(True)
print((v1==v2).all()) # tensor(True)
|
After load_pretrain_weight, Bert's output is aligned with huggingface.

import oneflow as flow
import libai
from libai.models import build_model
from libai.config import LazyCall
from load_huggingface_weight import load_huggingface_bert
from libai.utils import distributed as dist
import transformers
import torch
import numpy as np
input_ids = [[101, 1962, 2110, 739, 999, 1, 2, 3, 102]]
mask = [[1]*len(input_ids[0])]
# libai result
cfg = dict(
vocab_size=21128,
hidden_size=768,
hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
num_tokentypes=2,
add_pooling_layer=True,
initializer_range=0.02,
layernorm_eps=1e-12,
bias_gelu_fusion=False,  # set to False to align with huggingface
bias_dropout_fusion=False,  # set to False to align with huggingface
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,  # set to False to align with huggingface
add_binary_head=True,
amp_enabled=False,
apply_residual_post_layernorm=True
)
bert_lib = build_model(LazyCall(libai.models.BertModel)(cfg=cfg))
load_huggingface_bert(bert_lib, './pretrain/pytorch_model.bin', cfg['hidden_size'], cfg['num_attention_heads'])
input_of = flow.tensor(
    input_ids,
    dtype=flow.long,
    sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]),
    placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),
)
mask_of = flow.tensor(
    mask,
    dtype=flow.long,
    sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]),
    placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),
)
bert_lib.eval()
last_hidden_state_of, pooler_output_of = bert_lib(input_of, mask_of)
# huggingface result
bert_hug = transformers.BertModel.from_pretrained('./pretrain')
bert_hug.eval()
input_pt = torch.tensor(input_ids)
mask_pt = torch.tensor(mask)
last_hidden_state_pt = bert_hug(input_pt, mask_pt).last_hidden_state
res1 = last_hidden_state_of.detach().numpy()
res2 = last_hidden_state_pt.detach().numpy()
print(res1.sum())
print(res2.sum())
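Comparing the summed outputs is only a coarse check; an element-wise comparison could be added (my addition, not in the original snippet):
print(np.allclose(res1, res2, atol=1e-5))  # expected True when the outputs are aligned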
|
Some issues found when loading huggingface weights into LiBai's Bert and aligning the outputs; after the changes below, the output matches huggingface.

Parameter structure comparison (see the parameter structures of Bert in both libraries at the bottom):
- The embedding part matches huggingface with no issue.
- LayerNorm layers: LiBai places the LayerNorm at the input of each block, while huggingface places it at the output of each block. This is also fine; when loading huggingface weights we just load the LayerNorm of the preceding block.
- Linear layers: the weights only need a permute(1, 0).

Differences in internal computation logic between LiBai's Bert and huggingface's Bert that cause the outputs to diverge:
- In MultiheadAttention, two lines of code prevent this part's output from aligning with huggingface; the two computation methods below produce different q, k, v.
- Parts of the computation inside TransformerLayer differ from huggingface, which likewise keeps LiBai's output from aligning.
- Elsewhere in TransformerLayer, different computation logic also leads to inconsistent outputs.

With bias_gelu_fusion, bias_dropout_fusion and apply_query_key_layer_scaling in Bert set to False, I wrote a function that loads the huggingface pretrained model; after loading, LiBai's Bert with huggingface weights produces the same output as huggingface's Bert (using the same sentence as input). A sketch of the qkv part of such a function is included after this summary.

First, the Bert parameter structure in LiBai; then the huggingface parameter structure.
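For reference, a minimal sketch of the qkv-fusion step of such a loading function. It assumes the huggingface checkpoint stores query/key/value as three separate Linear layers (as transformers' BertSelfAttention does); the function name is mine and is not taken from the load_huggingface_weight module used above, and whether the result additionally needs the permute(1, 0) mentioned for Linear layers depends on LiBai's Linear convention.
import torch

def fuse_qkv_for_libai(q_w, k_w, v_w, q_b, k_b, v_b, num_heads, head_size):
    # Rearrange huggingface's separate Q/K/V weights ([hidden, hidden] each) into the
    # per-head interleaved layout used above: rows ordered h0_q, h0_k, h0_v, h1_q, ...
    hidden = q_w.size(1)
    w = torch.stack(
        [q_w.view(num_heads, head_size, hidden),
         k_w.view(num_heads, head_size, hidden),
         v_w.view(num_heads, head_size, hidden)], dim=1
    ).reshape(3 * num_heads * head_size, hidden)
    b = torch.stack(
        [q_b.view(num_heads, head_size),
         k_b.view(num_heads, head_size),
         v_b.view(num_heads, head_size)], dim=1
    ).reshape(-1)
    return w, b
If q_w, k_w, v_w are taken as the three 768-row blocks of the concatenated weight in the consolidated code above, this reproduces weight1 and bias1 there.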