Skip to content

Commit

Permalink
fix qwen matrix dimension alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
WAI-clear committed Nov 8, 2023
1 parent 65ec368 commit e1f9fbd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion paddlenlp/transformers/qwen/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
__all__ = [
"QWenBlock",
"QWenForCausalLM",
"QWenLMHeadModel",
"QWenPretrainedModel",
"QWenModel",
"QWenLMHead",
Expand Down Expand Up @@ -495,7 +496,7 @@ def _get_name_mappings(cls, config: QWenConfig) -> List[StateDictNameMapping]:
mapping[1] = "qwen." + mapping[1]

if config.architectures is not None:
if "QWenForCausalLM" in config.architectures:
if "QWenForCausalLM" or "QWenLMHeadModel" in config.architectures:
mappings.extend(
[
[
Expand Down Expand Up @@ -1043,3 +1044,6 @@ def forward(self, x):

output = self._norm(x.astype(paddle.float32)).astype(x.dtype)
return output * self.weight


# Alias: QWenLMHeadModel is bound to the very same class object as
# QWenForCausalLM (no subclass is created).
# NOTE(review): presumably this lets configs whose `architectures` entry
# names "QWenLMHeadModel" resolve to this implementation — confirm.
QWenLMHeadModel = QWenForCausalLM

0 comments on commit e1f9fbd

Please sign in to comment.