[Bugfix][Model] Add base class for vision-language models #4809
`ClassVar` to indicate vision models and improve error handling when incorrect `image_feature_size` is passed
I think the change looks good to me. cc @simon-mo
@@ -172,7 +174,7 @@ def forward(
         image_features = image_input
         vision_embeddings = self.multi_modal_projector(image_features)
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        _merge_vision_embeddings(
+        inputs_embeds = _merge_vision_embeddings(
Is this a bug?
No, it is just to make explicit the fact that `inputs_embeds` is modified.
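To illustrate the point, here is a minimal sketch of the pattern, using plain Python lists in place of tensors so it runs without any dependencies. The function name, the `IMAGE_TOKEN_ID` value, and the error message are assumptions made for this example and are not the code in this PR. The helper overwrites the image-token rows in place and also returns the same object, so the call site can assign the result and make the mutation visible to the reader.

```python
from typing import List, Sequence

# Hypothetical placeholder token id, chosen only for this sketch.
IMAGE_TOKEN_ID = 32000


def merge_vision_embeddings(
    input_ids: Sequence[int],
    inputs_embeds: List[List[float]],
    vision_embeddings: Sequence[List[float]],
    image_token_id: int = IMAGE_TOKEN_ID,
) -> List[List[float]]:
    """Overwrite rows at image-token positions in place and return the same list."""
    positions = [i for i, tok in enumerate(input_ids) if tok == image_token_id]
    if len(positions) != len(vision_embeddings):
        # Fail early with both numbers instead of an opaque shape error later.
        raise ValueError(
            f"Found {len(positions)} image tokens but "
            f"{len(vision_embeddings)} vision embeddings; check image_feature_size."
        )
    for pos, emb in zip(positions, vision_embeddings):
        inputs_embeds[pos] = emb
    # Returning the mutated object lets callers write
    # `inputs_embeds = merge_vision_embeddings(...)`.
    return inputs_embeds
```

Behaviour is unchanged whether or not the caller uses the return value; the assignment only documents that `inputs_embeds` has been updated.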
LGTM - thanks for the fix!
@simon-mo The |
let me retry one more time
@DarkLight1337 @ywang96 @rkooo567 Is this PR ready for merge?
Yes.
This PR adds a base class `VLMBase` to avoid importing `LlavaForConditionalGeneration` in `vllm/model_executor/model_loader/loader.py`, thus solving #4807.

Along the way, I have also ported the improved error handling logic regarding `image_feature_size` for the LLaVA model.

FIX #4807
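
For readers unfamiliar with the approach, the following is a rough, self-contained sketch of the idea and not the code in this PR. Only the name `VLMBase` comes from the description above; the `is_vlm` flag, the constructor signature, `ToyLlava`, `TextOnlyModel`, and `build_model` are hypothetical names invented for illustration. The loader checks against the shared base class instead of importing a concrete model class.

```python
from typing import ClassVar, Optional


class VLMBase:
    """Sketch of a shared base class for vision-language models (assumed shape)."""

    # Class-level marker; ClassVar tells type checkers it is not a per-instance field.
    is_vlm: ClassVar[bool] = True

    def __init__(self, image_feature_size: int) -> None:
        if image_feature_size <= 0:
            raise ValueError(
                f"image_feature_size must be positive, got {image_feature_size}"
            )
        self.image_feature_size = image_feature_size


class ToyLlava(VLMBase):
    """Hypothetical stand-in for a concrete vision-language model."""


class TextOnlyModel:
    """Hypothetical stand-in for a language-only model."""


def build_model(model_class: type, image_feature_size: Optional[int] = None):
    """Loader-side dispatch that never names a concrete VLM class."""
    if issubclass(model_class, VLMBase):
        if image_feature_size is None:
            raise ValueError(
                f"{model_class.__name__} is a vision-language model and "
                "requires image_feature_size"
            )
        return model_class(image_feature_size)
    return model_class()
```

With this layout, adding another vision-language model only requires subclassing the base class; the loader itself does not need to change.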