Skip to content

Commit

Permalink
LLaVaNeXT: pad on right if training (#32134)
Browse files Browse the repository at this point in the history
* pad on right if training

* docs

* add tests
  • Loading branch information
zucchini-nlp committed Jul 23, 2024
1 parent 251a240 commit 3aefb4e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 10 deletions.
7 changes: 7 additions & 0 deletions docs/source/en/model_doc/llava-next-video.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tre

- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.

<Tip warning={true}>

- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding".

</Tip>


- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that.

We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows:
Expand Down
7 changes: 7 additions & 0 deletions docs/source/en/model_doc/llava_next.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/

- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.

<Tip warning={true}>

- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding".

</Tip>


- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history, passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities. Below is an example of how to do that and the list of formats accepted by each checkpoint.

We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows:
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ def _merge_input_ids_with_image_features(
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)

left_padding = True
if batch_size > 1:
left_padding = True if not self.training else False
if batch_size > 1 and not self.training:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,8 @@ def _merge_input_ids_with_image_features(
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)

left_padding = True
if batch_size > 1:
left_padding = True if not self.training else False
if batch_size > 1 and not self.training:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:
Expand Down
38 changes: 35 additions & 3 deletions tests/models/llava_next/test_modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
self.batch_size = 3
self.num_channels = 3
self.image_size = 30
self.encoder_seq_length = 341
self.encoder_seq_length = 342
self.image_grid_pinpoints = [[32, 32]]

def get_config(self):
Expand Down Expand Up @@ -156,9 +156,7 @@ def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 images let's make sure we pass in 3 image tokens
input_ids[:, 1] = config.image_token_index
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
Expand Down Expand Up @@ -473,3 +471,37 @@ def test_small_model_integration_test_batch_matches_single(self):
self.processor.decode(output_batched[0], skip_special_tokens=True),
self.processor.decode(output_single[0], skip_special_tokens=True),
)

@slow
@require_bitsandbytes
def test_padding_side_when_merging_inputs(self):
model = LlavaNextForConditionalGeneration.from_pretrained(
"llava-hf/llava-v1.6-mistral-7b-hf",
load_in_4bit=True,
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
cats_image = Image.open(requests.get(url, stream=True).raw)
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
).to(torch_device)

# model is in eval mode by default so we should get pad on the left side
# we can check the first hidden-states (aka inputs embeds)
# the first element was lo-res image and we expect the first 1414 tokens to be all pads
output_eval = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item())

# otherwise padding is on the right side, so it's last 1414 tokens
self.processor.padding_side = "right"
inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
).to(torch_device)

model.train()
with torch.no_grad():
output_train = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item())
40 changes: 37 additions & 3 deletions tests/models/llava_next_video/test_modeling_llava_next_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
self.batch_size = 3
self.num_channels = 3
self.image_size = 30
self.encoder_seq_length = 468
self.encoder_seq_length = 469
self.image_grid_pinpoints = [[32, 32]]

def get_config(self):
Expand Down Expand Up @@ -166,9 +166,7 @@ def prepare_config_and_inputs(self):
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 images and videos let's make sure we pass in 3 special tokens
input_ids[:, 1] = config.image_token_index
input_ids[:, 2] = config.video_token_index
Expand Down Expand Up @@ -453,3 +451,39 @@ def test_small_model_integration_test_batch_matches_single(self):
self.processor.decode(output_batched[0], skip_special_tokens=True),
self.processor.decode(output_single[0], skip_special_tokens=True),
)

@slow
@require_bitsandbytes
def test_padding_side_when_merging_inputs(self):
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
"llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
)

inputs_batched = self.processor(
[self.prompt_video, self.prompt_image],
images=[self.image],
videos=[self.video],
return_tensors="pt",
padding=True,
).to(torch_device)

# model is in eval mode by default so we should get pad on the left side
# we can check the first hidden-states (aka inputs embeds)
# the first element was lo-res image and we expect the first 1482 tokens to be all pads
output_eval = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())

# otherwise padding is on the right side, so it's last 1482 tokens
self.processor.padding_side = "right"
inputs_batched = self.processor(
[self.prompt_video, self.prompt_image],
images=[self.image],
videos=[self.video],
return_tensors="pt",
padding=True,
).to(torch_device)

model.train()
with torch.no_grad():
output_train = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())

0 comments on commit 3aefb4e

Please sign in to comment.