I'm fine-tuning Baichuan2-7B-Base with LoRA using Seq2SeqTrainer, and strangely I hit an OOM right after the first prediction pass, while Baichuan2-7B-Chat does not. The Base model's OOM seems to come from switching back from prediction to training: the model appears to be loaded a second time, doubling GPU memory usage and causing the OOM. Comparing the modeling.py files of Base and Chat, the problem is mainly the following code in Base:
```python
class NormHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        if self.training:
            norm_weight = nn.functional.normalize(self.weight)
        elif self.first_flag:
            self.first_flag = False
            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)
```
Whereas in Chat it is:
```python
class NormHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        if self.training:
            norm_weight = nn.functional.normalize(self.weight)
            self.first_flag = True
        elif self.first_flag:
            self.first_flag = False
            self.weight.data = nn.functional.normalize(self.weight)
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)
```
After replacing the NormHead in Base with the one from Chat, the problem was solved. Could someone explain why this happens? Are the two modeling files interchangeable?
Is this mainly caused by the missing `self.first_flag = True` reset in the training branch? Without it, when Base switches from prediction back to training and then to prediction again, the intended branch is never re-entered?
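Another difference between the two snippets that may explain the memory doubling: the Base version rebinds `self.weight` to a brand-new `nn.Parameter` in eval mode, while the Chat version writes through `self.weight.data` in place. Rebinding allocates a new tensor while the old one stays alive wherever it is still referenced (for example, inside the optimizer's param groups), so both copies occupy GPU memory. A minimal sketch of the two behaviors, using a plain `nn.Linear` as a stand-in for NormHead (not the actual Baichuan2 code):

```python
import torch
import torch.nn as nn

# Base-style: rebind the attribute to a brand-new Parameter.
lin = nn.Linear(4, 4, bias=False)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)
old = lin.weight
lin.weight = nn.Parameter(nn.functional.normalize(old))
# The optimizer still holds the OLD tensor, so both copies stay alive,
# and the optimizer no longer updates the module's actual weight.
assert opt.param_groups[0]["params"][0] is old
assert lin.weight is not old

# Chat-style: in-place update keeps the same Parameter object.
lin2 = nn.Linear(4, 4, bias=False)
opt2 = torch.optim.SGD(lin2.parameters(), lr=0.1)
same = lin2.weight
lin2.weight.data = nn.functional.normalize(lin2.weight)
# Object identity is preserved; no duplicate tensor, optimizer still valid.
assert lin2.weight is same
assert opt2.param_groups[0]["params"][0] is lin2.weight
```

Under this reading, the missing `self.first_flag = True` reset explains the incorrect branch behavior after returning to training, and the `nn.Parameter` rebinding explains the extra memory.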