-
Notifications
You must be signed in to change notification settings - Fork 27k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Llava Onevision: add model #32673
Llava Onevision: add model #32673
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Thanks for adding this model @zucchini-nlp! If the model has its own paper - we should add it as its own model and leverage the diff converter for the model addition if changes are needed |
Agree with @amyeroberts, which is more in line with the philosophy |
@zucchini-nlp You've hardcoded a path in modeling_llava_next.py and I suppose it's for debugging?
|
Yes, will remove that. I am making a completely new modeling folder for Onevision currently |
8c3f82c
to
380e99a
Compare
@amyeroberts Ready for review. Added a new model LLaVa-Onevision. I didn't use diff because I realized that copying everything from llava-next is not a good decision. The current code incorporates all the recent refactors we've been working on and supports I added two processing files, one for image and another for video, since I've been planning to go over video-image LLMs and separate processing into two files. And later add more utility functions to the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding!
Main comment is about the new logic to use config.text_config
src/transformers/cache_utils.py
Outdated
|
||
if hasattr(config, "text_config"): # in case of composite models | ||
config = config.text_config | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure we want to do this this way: we could equally end up with several text models which might be called different things than text_config
. Instead, the cache should be created with the correct config i.e. StaticCache(config.text_config)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do it inside cache class also, and in that case handling of composite models will move to cache initialization of each cache class. I see comment from @gante about a helper function to get decoder/composite models' configs. In that case we can leave it here imo, and after @gante merged a PR modify this and other places to self.get_text_config(config)
Oops, realized it is in cache utils. Will move this to generation mixin and adjust tests if needed. Later gante will unify it in helper fn
final_layer = ( | ||
self.config.text_config.num_hidden_layers | ||
if hasattr(self.config, "text_config") | ||
else self.config.num_hidden_layers | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here - this is making strong (and possibly quite brittle) assumptions about the configs and what properties to use
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, let me see if we can rely on smth else and not config, and check out the case if a encoder-decoder
config is used. Prob we'll have to handle it same way as above, via a helper function
No, seems like we need to get num-layers, so will leave it as is to be handled properly in later PRs. I think this and other generations might be failing in EncoderDecoder
models because those are not tested with GenerationMixin
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK cc @gante for reference
src/transformers/models/llava_onevision/modeling_llava_onevision.py
Outdated
Show resolved
Hide resolved
src/transformers/models/llava_onevision/processing_llava_onevision.py
Outdated
Show resolved
Hide resolved
tests/generation/test_utils.py
Outdated
@@ -459,7 +459,8 @@ def test_greedy_generate_dict_outputs(self): | |||
# Retrocompatibility check | |||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) | |||
|
|||
self._check_outputs(output_generate, input_ids, model.config) | |||
config = model.config if not hasattr(model.config, "text_config") else model.config.text_config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This feels brittle and makes a strong assumption about the structure of our configs it would be good to have @gante's view on this here.
I'd suggest moving this handling into _check_outputs
so the logic change only has to be in on place rather than every _check_outputs
call, which already has conditional logic on e.g. config.is_encoder_decoder
within the utility
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are starting to have more places where we need to pull the decoder part of a config in a composite/encoder-decoder model:
- in the generation config
- in while saving a pretrained config, to ensure it doesn't contain generate flags [open PR]
- here
Note that they are all generate
-related. In this case, the shape of the generate
outputs depends on the decoder, so we need to pull the decoder config for testing purposes.
@amyeroberts I was thinking of abstracting this logic into GenerationMixin
, in a function that pulls the decoder config given a config. WDYT? :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd be pro! This should help encapsulate the logic without having to handle it everywhere else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, I'll open a PR after #32659 gets merged (to avoid conflicts :) )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, will port these to _check_output
until the PR is merged
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI opening a PR to add the generic function AND fixing _check_output
(#33212 also needs it :D )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
…sion.py Co-authored-by: amyeroberts <[email protected]>
@amyeroberts should be ready! I left the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great - thanks for all the work adding!
final_layer = ( | ||
self.config.text_config.num_hidden_layers | ||
if hasattr(self.config, "text_config") | ||
else self.config.num_hidden_layers | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK cc @gante for reference
# Padding side can be in TextKwargs but is not accepted by the tokenizer | ||
_ = output_kwargs["text_kwargs"].pop("padding_side", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should find a way to handle this better as we don't want to have to add this for all the processors cc @molbap
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can just remove padding_side
from TextKwargs, which I believe is already done in one of the PRs for standardizaing
Co-authored-by: amyeroberts <[email protected]>
Will need some changes in video processing as discussed here (LLaVA-VL/LLaVA-NeXT#144). Comes out the demo notebook was buggy and videos should be processed in a different way. Will work on it when I get reply from authors on what should be the correct way |
@zucchini-nlp Thanks for the update! I'll mark myself a re-requesting for review then as this will be a significant change |
@amyeroberts yes, the tests should be green now. Merging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great - thanks for all the work iterating on this!
Only thing is to add a short description of the model and resolve the failing SDPA test before merge
|
||
## Overview | ||
|
||
The LLaVA-Onevision model was proposed in [LLaVA-OneVision: Easy Visual Task Transfer](https://arxiv.org/abs/2408.03326) by <Bo Li, Yuanhan Zhang, Dong Guo, Renrui Zhang, Feng Li, Hao Zhang, Kaichen Zhang, Yanwei Li, Ziwei Liu, Chunyuan Li |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a short description of the model here? This is what will be used in the release notes. A good example is the music gen page
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
oke, will merge it now as the model works. The only thing I am struggling now is ONNX conversion which is not related to the PR |
* working version * fix copies * update * tests * update docs * codestyle * add more tests * add returns for docs * clean up * Update src/transformers/models/llava_onevision/processing_llava_onevision.py Co-authored-by: amyeroberts <[email protected]> * updates * codestyle * style * shouldn't be reversed * [run-slow] llava_onevision * [run-slow] llava_onevision * add pooling in videos * [run-slow] llava_onevision * num-logits-to-keep * [run-slow] llava_onevision * [run-slow] llava_onevision * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <[email protected]> * video matched orig impl * fix tests * chat template was modified * Update docs/source/en/model_doc/llava_onevision.md Co-authored-by: amyeroberts <[email protected]> * add morer info in the doc page --------- Co-authored-by: amyeroberts <[email protected]>
* working version * fix copies * update * tests * update docs * codestyle * add more tests * add returns for docs * clean up * Update src/transformers/models/llava_onevision/processing_llava_onevision.py Co-authored-by: amyeroberts <[email protected]> * updates * codestyle * style * shouldn't be reversed * [run-slow] llava_onevision * [run-slow] llava_onevision * add pooling in videos * [run-slow] llava_onevision * num-logits-to-keep * [run-slow] llava_onevision * [run-slow] llava_onevision * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <[email protected]> * video matched orig impl * fix tests * chat template was modified * Update docs/source/en/model_doc/llava_onevision.md Co-authored-by: amyeroberts <[email protected]> * add morer info in the doc page --------- Co-authored-by: amyeroberts <[email protected]>
hello any idea how to get a llama model working with this pr. mylesgoose/Llama-3.1-Minitron-4B-Llava-Nvidia-siglip-ov regards myles I have tried adding the similar files to the huggingface repo as per the qwen model which works. there is small issue I think because the 4b model is a stripped down 8b model with weight sizes. |
@mylesgoose I am not sure if you got the model from tuning using original repo or in another way. If you used the original repo, feel free to convert to HF style using We try to reserve GH for issues/bugs and feature requests, feel free to open a thread in the forum for any further questions 🤗 |
@zucchini-nlp no I made the model using your git repo repo llava one vision. it's a llama 4b model. and this is a feature request. I trained it. because the model is only supported of qwen type models. at present this transformers can't understand my weight are smaller. I am not sure exactly how to modify this one vision transformers to load the llama model type. can you have a look at the huggingface repo.and you will see what I mean. |
@mylesgoose I am not sure I understand you. The repo shows a trained model with a different LM backbone and has inference script with llava-vl repo code. So I guess you did training and can do generation with llava-vl repo In transformers we already support any LM backbone, and all you need is to convert the weights, which has several conversion scripts for different llava models. So this is not a feature request. In case you still believe we can help you by adding/modifying llava code, please open an issue and provide minimal script to see what you want to achieve and what output is expected from running it 🤗 |
Okay thanks @zucchini-nlp I understand now. I found that script and ran one to generate the bin file and that worked. however the script does not work on the llama model. so I will modify it tomorrow and get it to work and then return working adjusted one. I think it also does not add tokens for video etc also. or check if pad token exists prior to creating one. and created meta tensors on devices and could not use in latter step. will figure it out. thanks allot. |
This seems to have worked. @zucchini-nlp import argparse
import torch
from transformers import (
AddedToken,
AutoConfig,
AutoImageProcessor,
AutoTokenizer,
LlavaConfig,
LlavaForConditionalGeneration,
LlavaProcessor,
SiglipVisionConfig,
)
EPILOG_TXT = """Example:
python convert_llava_weights_to_hf.py \\
--text_model_id mylesgoose/Llava-Llama-3.1-Minitron-4B-Nvidia-siglip-ov \\
--vision_model_id google/siglip-so400m-patch14-384 \\
--output_hub_path mylesgoose/Llama-3.1-Minitron-4B-Llava-Nvidia-siglip-ov-hf \\
--old_state_dict_id /home/myles/LLaVA-NeXT/tmp/hf_models/llava-Llama-3.1-Minitron-4B/model_state_dict.bin
Example for creating the old state dict file with Python:
import torch
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
# Load model
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
model = LlavaLlamaForCausalLM.from_pretrained("mylesgoose/Llava-Llama-3.1-Minitron-4B-Nvidia-siglip-ov", low_cpu_mem_usage=True, **kwargs)
# Load vision tower
model.get_vision_tower().load_model()
# Save state dict
torch.save(model.state_dict(), "/home/myles/LLaVA-NeXT/tmp/hf_models/llava-Llama-3.1-Minitron-4B/model_state_dict.bin")
"""
KEYS_TO_MODIFY_MAPPING = {
"model.vision_tower.vision_tower.": "vision_tower.",
"model.vision_tower.": "vision_tower.",
".vision_resampler": "", # Not used in the original implementation
"model.mm_projector": "multi_modal_projector",
"model": "model.model",
"vision_model.model": "vision_model",
"lm_head": "language_model.lm_head",
"model.model": "language_model.model",
"multi_modal_projector.0": "multi_modal_projector.linear_1",
"multi_modal_projector.2": "multi_modal_projector.linear_2",
}
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
# Skip keys ending with .inv_freq (if present)
if key.endswith(".inv_freq"):
continue
# Modify keys based on the mapping
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
return new_state_dict
def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.bfloat16)
text_config = AutoConfig.from_pretrained(text_model_id)
tokenizer = AutoTokenizer.from_pretrained(text_model_id)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
tokenizer.add_tokens(AddedToken("<video>", special=True, normalized=False), special_tokens=True)
if tokenizer.pad_token_id is None:
tokenizer.add_special_tokens({"pad_token": "<|finetune_right_pad_id|>"})
image_processor = AutoImageProcessor.from_pretrained(vision_model_id)
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
# Set vision_config based on the vision model
vision_config = SiglipVisionConfig.from_pretrained(vision_model_id)
# Adjust the vision_config to match your model's parameters
vision_config.hidden_size = 1152
vision_config.num_attention_heads = 16
vision_config.num_hidden_layers = 24
vision_config.intermediate_size = 4304
vision_config.patch_size = 14
vision_config.image_size = 384
vision_config.vision_use_head = False
config = LlavaConfig(
text_config=text_config,
vision_config=vision_config.to_dict(),
)
# Set other configuration parameters as needed
config.pad_token_id = tokenizer.pad_token_id
config.image_token_index = tokenizer.convert_tokens_to_ids("<image>")
config.video_token_index = tokenizer.convert_tokens_to_ids("<video>")
config.vision_feature_layer = -2 # Adjust if necessary
config.vision_tower_pretrained = vision_model_id
# Instantiate the model
model = LlavaForConditionalGeneration(config)
# Load the state dictionary from the provided path
state_dict_path = old_state_dict_id
state_dict = torch.load(state_dict_path, map_location="cpu")
# Convert state dict to match Hugging Face model's expected keys
state_dict = convert_state_dict_to_hf(state_dict)
# Load the state dictionary into the model
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")
# Resize token embeddings to accommodate new tokens
model.resize_token_embeddings(len(tokenizer))
# Get the updated sizes
pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
vocab_size = pre_expansion_embeddings.size(0)
new_vocab_size = model.language_model.model.embed_tokens.weight.size(0)
num_new_tokens = new_vocab_size - vocab_size
print(f"Original vocab size: {vocab_size}")
print(f"New vocab size: {new_vocab_size}")
print(f"Number of new tokens: {num_new_tokens}")
# Initialize new token embeddings only if there are new tokens
if num_new_tokens > 0:
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size(0)
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(
mu.cpu(), covariance_matrix=1e-5 * sigma.cpu()
)
new_embeddings = torch.stack(
[dist.sample() for _ in range(num_new_tokens)], dim=0
).to(pre_expansion_embeddings.device)
model.language_model.model.embed_tokens.weight.data[vocab_size:] = new_embeddings
model.language_model.lm_head.weight.data[vocab_size:] = new_embeddings
else:
print("No new tokens to add to embeddings.")
# Save the converted model and processor
model.push_to_hub(output_hub_path)
processor.push_to_hub(output_hub_path)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--text_model_id",
default="mylesgoose/Llava-Llama-3.1-Minitron-4B-Nvidia-siglip-ov",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
default="google/siglip-so400m-patch14-384",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
default="mylesgoose/Llama-3.1-Minitron-4B-Llava-Nvidia-siglip-ov-hf",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
default="/home/myles/LLaVA-NeXT/tmp/hf_models/llava-Llama-3.1-Minitron-4B/model_state_dict.bin",
help="Path to the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_llava_llama_to_hf(
args.text_model_id,
args.vision_model_id,
args.output_hub_path,
args.old_state_dict_id
)
if __name__ == "__main__":
main() |
What does this PR do?
Adds llava onevision models (https://arxiv.org/abs/2408.03326) which are basically same as LLaVa-NeXT with have a different vision tower and image grid pinpoints