forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Chameleon: add model (huggingface#31534)
* Chameleon model integration Co-authored-by: Jacob Kahn <[email protected]> Co-authored-by: Leonid Shamis <[email protected]> * fix 7B, again. mask away image tokens * Apply suggestions from code review Co-authored-by: Arthur <[email protected]> * remove pretrained_config_map * make fixup passing up to utils/check_config_docstrings.py; vqgan moved to the modeling file * remove tokenizer (use llama's); remove codechameleon tests * a few copied from statements and minor changes * copied from in ChameleonModel * some copies in ChameleonForCausalLM * a few more copies * VQModel moved to ChameleonModel (as opposed to being in the processor) * ChameleonProcessor ready * Fix chameleon weights convert * update conversion script * clean-up processing * update modeling a bit * update * update (throws error...) * correct conversion ready * fix tests * fix docs * docs * ve swin norm * fix device for vocab map * add normalization * update * update script with rope rotations * final fix on model conversion * add slow tests * more info in docs * fix repo consistency tests * fix repo tests * fix-copies * hope this will make CI happy * fix for 30b model * Update docs/source/en/index.md Co-authored-by: amyeroberts <[email protected]> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/modeling_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <[email protected]> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <[email protected]> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <[email protected]> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/auto/configuration_auto.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/modeling_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/processing_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/chameleon/processing_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update tests/models/chameleon/test_modeling_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update tests/models/chameleon/test_modeling_chameleon.py Co-authored-by: amyeroberts <[email protected]> * Update tests/models/chameleon/test_modeling_chameleon.py Co-authored-by: amyeroberts <[email protected]> * address comments * remove assertion in conversion script * add image processor test * not copied * port changes for qk layernorm * fix-copies * read token decorator for tests * [run-slow] chameleon * one more read-token * address some comments * qk norm changes * tests and repo check * moved rope permutations to conversion, YAY! * fix past kv check * docs * layernorm done! * let's be consistent in naming * fix slow tests * weird thing with slow CI, but let's see * once more try * remove past-kv as tuple following llama * ignore * style --------- Co-authored-by: Pablo Montalvo <[email protected]> Co-authored-by: ArthurZucker <[email protected]> Co-authored-by: jacobkahn <[email protected]> Co-authored-by: Leonid Shamis <[email protected]> Co-authored-by: Leonid Shamis <[email protected]> Co-authored-by: Joao Gante <[email protected]> Co-authored-by: Arthur <[email protected]> Co-authored-by: Joao Gante <[email protected]> Co-authored-by: amyeroberts <[email protected]>
- Loading branch information
1 parent
5ecd65b
commit 108c39b
Showing
24 changed files
with
3,950 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
<!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | ||
rendered properly in your Markdown viewer. | ||
--> | ||
|
||
# Chameleon | ||
|
||
## Overview | ||
|
||
The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models | ||
](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet. | ||
|
||
|
||
The abstract from the paper is the following: | ||
|
||
*We present Chameleon, a family of early-fusion token-based mixed-modal models capable of understanding and generating images and text in any arbitrary sequence. We outline a stable training | ||
approach from inception, an alignment recipe, and an architectural parameterization tailored for the | ||
early-fusion, token-based, mixed-modal setting. The models are evaluated on a comprehensive range | ||
of tasks, including visual question answering, image captioning, text generation, image generation, and | ||
long-form mixed modal generation. Chameleon demonstrates broad and general capabilities, including | ||
state-of-the-art performance in image captioning tasks, outperforms Llama-2 in text-only tasks while | ||
being competitive with models such as Mixtral 8x7B and Gemini-Pro, and performs non-trivial image | ||
generation, all in a single model. It also matches or exceeds the performance of much larger models, | ||
including Gemini Pro and GPT-4V, according to human judgments on a new long-form mixed-modal | ||
generation evaluation, where either the prompt or outputs contain mixed sequences of both images and | ||
text. Chameleon marks a significant step forward in a unified modeling of full multimodal documents* | ||
|
||
|
||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/chameleon_arch.png" | ||
alt="drawing" width="600"/> | ||
|
||
<small> Chameleon incorporates a vector quantizer module to transform images into discrete tokens. That also enables image geenration using an auto-regressive transformer. Taken from the <a href="https://arxiv.org/abs/2405.09818v1">original paper.</a> </small> | ||
|
||
This model was contributed by [joaogante](https://huggingface.co/joaogante) and [RaushanTurganbay](https://huggingface.co/RaushanTurganbay). | ||
The original code can be found [here](https://github.com/facebookresearch/chameleon). | ||
|
||
|
||
## Usage tips | ||
|
||
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating. | ||
|
||
- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question, instead of an open question. | ||
|
||
- Chameleon generates in chat format which means that the generated text will always be the "assistant's turn". You can enable a text completion generation by passing `return_for_text_completion=True` when calling the processor. | ||
|
||
> [!NOTE] | ||
> Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. For special image token we didn't add a new one but used one of the reserved tokens: `<reserved08707>`. | ||
## Usage example | ||
|
||
### Single image inference | ||
|
||
Here's how to load the model and perform inference in half-precision (`torch.float16`): | ||
|
||
```python | ||
from transformers import ChameleonProcessor, ChameleonForCausalLM | ||
import torch | ||
from PIL import Image | ||
import requests | ||
|
||
processor = ChameleonProcessor.from_pretrained("meta-chameleon") | ||
model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto") | ||
|
||
# prepare image and text prompt | ||
url = "https://bjiujitsu.com/wp-content/uploads/2021/01/jiu_jitsu_belt_white_1.jpg" | ||
image = Image.open(requests.get(url, stream=True).raw) | ||
prompt = "What color is the belt in this image?<image>" | ||
|
||
inputs = processor(prompt, image, return_tensors="pt").to(model.device) | ||
|
||
# autoregressively complete prompt | ||
output = model.generate(**inputs, max_new_tokens=50) | ||
print(processor.decode(output[0], skip_special_tokens=True)) | ||
``` | ||
|
||
### Multi image inference | ||
|
||
Chameleon can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it: | ||
|
||
```python | ||
from transformers import ChameleonProcessor, ChameleonForCausalLM | ||
import torch | ||
from PIL import Image | ||
import requests | ||
|
||
processor = ChameleonProcessor.from_pretrained("meta-chameleon") | ||
model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto") | ||
|
||
# Get three different images | ||
url = "https://www.ilankelman.org/stopsigns/australia.jpg" | ||
image_stop = Image.open(requests.get(url, stream=True).raw) | ||
|
||
url = "http://images.cocodataset.org/val2017/000000039769.jpg" | ||
image_cats = Image.open(requests.get(url, stream=True).raw) | ||
|
||
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" | ||
image_snowman = Image.open(requests.get(url, stream=True).raw) | ||
|
||
# Prepare a batched prompt, where the first one is a multi-image prompt and the second is not | ||
prompts = [ | ||
"What do these images have in common?<image><image>", | ||
"<image>What is shown in this image?" | ||
] | ||
|
||
# We can simply feed images in the order they have to be used in the text prompt | ||
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens | ||
inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device) | ||
|
||
# Generate | ||
generate_ids = model.generate(**inputs, max_new_tokens=50) | ||
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) | ||
``` | ||
|
||
## Model optimization | ||
|
||
### Quantization using Bitsandbytes | ||
|
||
The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: | ||
|
||
```python | ||
from transformers import ChameleonForCausalLM, BitsAndBytesConfig | ||
|
||
# specify how to quantize the model | ||
quantization_config = BitsAndBytesConfig( | ||
load_in_4bit=True, | ||
bnb_4bit_quant_type="nf4", | ||
bnb_4bit_compute_dtype=torch.float16, | ||
) | ||
|
||
model = ChameleonForCausalLM.from_pretrained("meta-chameleon", quantization_config=quantization_config, device_map="auto") | ||
``` | ||
|
||
### Use Flash-Attention 2 and SDPA to further speed-up generation | ||
|
||
The models supports both, Flash-Attention 2 and PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) which can be enables for optimization. SDPA is the default options when you load the model, If you want to switch for Flash Attention 2, first make sure to install flash-attn. Refer to the [original repository](https://github.com/Dao-AILab/flash-attention) regarding that package installation. Simply change the snippet above with: | ||
|
||
```python | ||
from transformers import ChameleonForCausalLM | ||
|
||
model = ChameleonForCausalLM.from_pretrained( | ||
model_id, | ||
torch_dtype=torch.float16, | ||
low_cpu_mem_usage=True, | ||
attn_implementation="flash_attention_2" | ||
).to(0) | ||
``` | ||
|
||
## ChameleonConfig | ||
|
||
[[autodoc]] ChameleonConfig | ||
|
||
## ChameleonVQVAEConfig | ||
|
||
[[autodoc]] ChameleonVQVAEConfig | ||
|
||
## ChameleonProcessor | ||
|
||
[[autodoc]] ChameleonProcessor | ||
|
||
## ChameleonImageProcessor | ||
|
||
[[autodoc]] ChameleonImageProcessor | ||
- preprocess | ||
|
||
## ChameleonVQVAE | ||
|
||
[[autodoc]] ChameleonVQVAE | ||
- forward | ||
|
||
## ChameleonModel | ||
|
||
[[autodoc]] ChameleonModel | ||
- forward | ||
|
||
## ChameleonForCausalLM | ||
|
||
[[autodoc]] ChameleonForCausalLM | ||
- forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,6 +42,7 @@ | |
byt5, | ||
camembert, | ||
canine, | ||
chameleon, | ||
chinese_clip, | ||
clap, | ||
clip, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.