diff --git a/docs/source/en/model_doc/llava-next-video.md b/docs/source/en/model_doc/llava-next-video.md index 88e41efc29c87c..48e50f950621e8 100644 --- a/docs/source/en/model_doc/llava-next-video.md +++ b/docs/source/en/model_doc/llava-next-video.md @@ -43,6 +43,13 @@ The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tre - We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating. + + +- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding". + + + + - Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that. We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows: diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index 9e7caa37d7b9bc..0c25ed32db5ab3 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -46,6 +46,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/ - We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating. + + +- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding". + + + + - Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history, passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities. Below is an example of how to do that and the list of formats accepted by each checkpoint. We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows: diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 5b897b817330b7..ad76561df54fd7 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -518,8 +518,8 @@ def _merge_input_ids_with_image_features( _left_padding = torch.any(attention_mask[:, 0] == 0) _right_padding = torch.any(attention_mask[:, -1] == 0) - left_padding = True - if batch_size > 1: + left_padding = True if not self.training else False + if batch_size > 1 and not self.training: if _left_padding and not _right_padding: left_padding = True elif not _left_padding and _right_padding: diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index f2ccb99e618753..e3264dfd91e1a1 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -562,8 +562,8 @@ def _merge_input_ids_with_image_features( _left_padding = torch.any(attention_mask[:, 0] == 0) _right_padding = torch.any(attention_mask[:, -1] == 0) - left_padding = True - if batch_size > 1: + left_padding = True if not self.training else False + if batch_size > 1 and not self.training: if _left_padding and not _right_padding: left_padding = True elif not _left_padding and _right_padding: diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 69794a85d9fe39..70d91002a91bc3 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -123,7 +123,7 @@ def __init__( self.batch_size = 3 self.num_channels = 3 self.image_size = 30 - self.encoder_seq_length = 341 + self.encoder_seq_length = 342 self.image_grid_pinpoints = [[32, 32]] def get_config(self): @@ -156,9 +156,7 @@ def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values = config_and_inputs input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2 - # make attention mask left-padded to avoid issues with "model has no attribute padding_side" attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device) - attention_mask[:, :1] = 0 # we are giving 3 images let's make sure we pass in 3 image tokens input_ids[:, 1] = config.image_token_index labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) @@ -473,3 +471,37 @@ def test_small_model_integration_test_batch_matches_single(self): self.processor.decode(output_batched[0], skip_special_tokens=True), self.processor.decode(output_single[0], skip_special_tokens=True), ) + + @slow + @require_bitsandbytes + def test_padding_side_when_merging_inputs(self): + model = LlavaNextForConditionalGeneration.from_pretrained( + "llava-hf/llava-v1.6-mistral-7b-hf", + load_in_4bit=True, + ) + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e" + cats_image = Image.open(requests.get(url, stream=True).raw) + lowres_img = Image.open(requests.get(lowres_url, stream=True).raw) + + inputs_batched = self.processor( + [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True + ).to(torch_device) + + # model is in eval mode by default so we should get pad on the left side + # we can check the first hidden-states (aka inputs embeds) + # the first element was lo-res image and we expect the first 1414 tokens to be all pads + output_eval = model(**inputs_batched, output_hidden_states=True) + self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item()) + + # otherwise padding is on the right side, so it's last 1414 tokens + self.processor.padding_side = "right" + inputs_batched = self.processor( + [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True + ).to(torch_device) + + model.train() + with torch.no_grad(): + output_train = model(**inputs_batched, output_hidden_states=True) + self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item()) diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index afe3062fb50e0e..9ba7ef869ddf00 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -124,7 +124,7 @@ def __init__( self.batch_size = 3 self.num_channels = 3 self.image_size = 30 - self.encoder_seq_length = 468 + self.encoder_seq_length = 469 self.image_grid_pinpoints = [[32, 32]] def get_config(self): @@ -166,9 +166,7 @@ def prepare_config_and_inputs(self): def prepare_config_and_inputs_for_common(self): config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs() input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2 - # make attention mask left-padded to avoid issues with "model has no attribute padding_side" attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device) - attention_mask[:, :1] = 0 # we are giving 3 images and videos let's make sure we pass in 3 special tokens input_ids[:, 1] = config.image_token_index input_ids[:, 2] = config.video_token_index @@ -453,3 +451,39 @@ def test_small_model_integration_test_batch_matches_single(self): self.processor.decode(output_batched[0], skip_special_tokens=True), self.processor.decode(output_single[0], skip_special_tokens=True), ) + + @slow + @require_bitsandbytes + def test_padding_side_when_merging_inputs(self): + model = LlavaNextVideoForConditionalGeneration.from_pretrained( + "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True + ) + + inputs_batched = self.processor( + [self.prompt_video, self.prompt_image], + images=[self.image], + videos=[self.video], + return_tensors="pt", + padding=True, + ).to(torch_device) + + # model is in eval mode by default so we should get pad on the left side + # we can check the first hidden-states (aka inputs embeds) + # the first element was lo-res image and we expect the first 1482 tokens to be all pads + output_eval = model(**inputs_batched, output_hidden_states=True) + self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item()) + + # otherwise padding is on the right side, so it's last 1482 tokens + self.processor.padding_side = "right" + inputs_batched = self.processor( + [self.prompt_video, self.prompt_image], + images=[self.image], + videos=[self.video], + return_tensors="pt", + padding=True, + ).to(torch_device) + + model.train() + with torch.no_grad(): + output_train = model(**inputs_batched, output_hidden_states=True) + self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())