Enable input_embeds for ChatGLM / ChatGLMForConditionalGeneration #5775
Conversation
Fix a typo and minor bugs to enable the input_embeds input rather than input_ids.
Thanks for your contribution!
LGTM
```python
else:
    raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
    inputs_embeds = self.word_embeddings(input_ids)
inputs_embeds = inputs_embeds.transpose([1, 0, 2])
```
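For context, a hedged usage sketch of what this change enables: passing precomputed inputs_embeds (e.g. soft prompts) instead of input_ids. The checkpoint name, the tokenizer outputs, and the availability of get_input_embeddings() are assumptions here, not taken from the PR.

```python
# Hedged usage sketch (not from the PR): feed inputs_embeds instead of input_ids.
# Assumptions: the checkpoint name, that the tokenizer returns attention_mask /
# position_ids, and that get_input_embeddings() is implemented for this model.
from paddlenlp.transformers import ChatGLMForConditionalGeneration, ChatGLMTokenizer

tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b")
model = ChatGLMForConditionalGeneration.from_pretrained("THUDM/chatglm-6b")

inputs = tokenizer("你好", return_tensors="pd")

# Look up the token embeddings manually, then pass them in place of input_ids.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

outputs = model(
    inputs_embeds=inputs_embeds,
    # With input_ids omitted, attention_mask / position_ids can no longer be
    # derived from it and must be supplied explicitly (see the review below).
    attention_mask=inputs.get("attention_mask"),
    position_ids=inputs.get("position_ids"),
)
```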
LGTM
A small pitfall here is that input_ids is treated as the default input: the export and inference logic uses input_ids by default.
Other transformer models (e.g. ERNIE) also default input_ids to None. The PretrainedModel instances additionally check that input_ids and inputs_embeds are not both None and not both provided, so this should be fine here.
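For reference, the check mentioned above usually follows the pattern below; this is a minimal sketch with illustrative names, not the exact ERNIE / ChatGLM code.

```python
# Minimal sketch of the mutual-exclusion check described above.
# Names are illustrative; this is not the exact ERNIE / ChatGLM implementation.
def resolve_batch_shape(input_ids=None, inputs_embeds=None):
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is not None:
        return input_ids.shape[:2]        # (batch_size, seq_length)
    if inputs_embeds is not None:
        return inputs_embeds.shape[:2]    # (batch_size, seq_length)
    raise ValueError("You have to specify either input_ids or inputs_embeds")
```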
Codecov Report
```
@@            Coverage Diff             @@
##           develop    #5775      +/-   ##
===========================================
- Coverage    61.60%   61.57%   -0.03%
===========================================
  Files          489      489
  Lines        68500    68540      +40
===========================================
+ Hits         42197    42205       +8
- Misses       26303    26335      +32
```
```diff
@@ -826,7 +826,7 @@ def update_model_kwargs_for_generation(
     def forward(
         self,
-        input_ids,
+        input_ids=None,
```
When input_ids defaults to None here, the branch logic in ChatGLMModel for position_ids is None and attention_mask is None also needs to be adjusted, because both branches depend on input_ids.
1) For branches that can support None: when input_ids is None, obtain the input_ids-related parameters from input_embeds instead.
2) For branches that cannot support None: raise an explicit exception that explains the reason.
Updated following option 2): input_ids and attention_mask / position_ids are now required not to be None at the same time.
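A minimal sketch of what that option-2 check might look like; the helper name is hypothetical and this is not the exact code merged in the PR.

```python
# Hedged sketch of the option-2 check described above; _check_explicit_inputs
# is a hypothetical helper, not the exact code merged in this PR.
def _check_explicit_inputs(input_ids, attention_mask, position_ids):
    if input_ids is None and attention_mask is None:
        raise ValueError(
            "attention_mask cannot be built without input_ids; "
            "pass attention_mask explicitly when using inputs_embeds."
        )
    if input_ids is None and position_ids is None:
        raise ValueError(
            "position_ids cannot be built without input_ids; "
            "pass position_ids explicitly when using inputs_embeds."
        )
```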
line 693
PR types
Bug fixes
PR changes
Models / APIs
Description
See changelog / Commit