Enable input_embeds for ChatGLM / ChatGLMForConditionalGeneration #5775
```diff
@@ -501,13 +501,13 @@ def forward(
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-            inputs_embeds = inputs_embeds.transpose([1, 0, 2])
+        inputs_embeds = inputs_embeds.transpose([1, 0, 2])

         if cache is None:
             if self.config.pre_seq_len is not None:
```
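For context on the first change in this hunk: in Paddle, `Tensor.shape` is a plain Python list, so `shape[:2]` always yields exactly two values, and unpacking it into three names raises a `ValueError`. A minimal sketch with made-up shapes:

```python
import paddle

# inputs_embeds is [batch_size, seq_length, hidden_size]; shapes here are illustrative.
inputs_embeds = paddle.ones([2, 16, 64])

# Old code: shape[:2] is a two-element list, so unpacking into three names
# raises "ValueError: not enough values to unpack (expected 3, got 2)".
# batch_size, seq_length, _ = inputs_embeds.shape[:2]

# New code: matches the two values shape[:2] actually produces.
batch_size, seq_length = inputs_embeds.shape[:2]
print(batch_size, seq_length)  # 2 16
```

The second change dedents the transpose so that caller-supplied `inputs_embeds` also get transposed (presumably to the sequence-first layout the stack expects), not only embeddings computed from `input_ids`.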
```diff
@@ -690,6 +690,10 @@ def forward(
         use_cache: bool = None,
         return_dict: bool = None,
     ):
+        if input_ids is None:
+            assert position_ids is not None, "`position_ids` must be explicitly specified when input_ids is None."
+            assert attention_mask is not None, "`attention_mask` must be explicitly specified when input_ids is None."
+
         if attention_mask is None:
             attention_mask = self.get_masks(input_ids)
```
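Since the default attention mask and position ids are derived from `input_ids`, the new asserts push that responsibility onto the caller whenever embeddings are passed directly. A hedged sketch of what such a call might look like; the shapes and the exact ChatGLM mask/position-id conventions below are illustrative assumptions, not taken from this PR:

```python
import paddle

batch_size, seq_length, hidden_size = 1, 8, 4096

# Caller-provided embeddings replace input_ids entirely.
inputs_embeds = paddle.randn([batch_size, seq_length, hidden_size])

# When input_ids is None, both of these must be supplied explicitly;
# the exact mask dtype/shape expected by ChatGLM may differ.
attention_mask = paddle.ones([batch_size, 1, seq_length, seq_length], dtype="int64")
position_ids = paddle.arange(seq_length).unsqueeze(0)

# outputs = model(
#     input_ids=None,
#     inputs_embeds=inputs_embeds,
#     attention_mask=attention_mask,
#     position_ids=position_ids,
# )
```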
```diff
@@ -826,7 +830,7 @@ def update_model_kwargs_for_generation(

     def forward(
         self,
-        input_ids,
+        input_ids=None,
         position_ids=None,
         attention_mask=None,
         cache=None,
```

Review comments on the `input_ids=None` change:

- Here, 1) this could be allowed to be …
- Updated following approach 2): `input_ids` and `attention_mask` / `position_ids` must not be None at the same time.
- See line 693.
LGTM
A small gotcha here is that `input_ids` is the default input; the export and inference logic uses `input_ids` by default.
Other transformer models (e.g. ERNIE) also default `input_ids` to None. In addition, `PretrainedModel` instances check that `input_ids` and `inputs_embeds` are neither both None nor both provided, so this should be fine.
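For reference, a rough sketch of the kind of mutual-exclusivity check the reviewer is describing, modeled loosely on what other transformer models do; this is not the literal PaddleNLP code:

```python
def check_inputs(input_ids=None, inputs_embeds=None):
    # Exactly one of input_ids / inputs_embeds must be provided.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
```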