
Commit 29c3bb3

Fix LLaVA-NeXT input processor and cleanup code
DarkLight1337 committed Jun 5, 2024
1 parent 9bc5fcc commit 29c3bb3
Showing 1 changed file with 82 additions and 52 deletions.
134 changes: 82 additions & 52 deletions vllm/multimodal/image.py
@@ -21,22 +21,6 @@
 _cached_get_image_processor = lru_cache(get_image_processor)


-def _get_dummy_seq_data(
-    *,
-    seq_len: int,
-    image_token_id: int,
-    image_feature_size: int,
-) -> SequenceData:
-    # NOTE: We assume that <image> token is repeated `image_feature_size` times
-    # and then concatenated with the text prompt
-    # TODO: Enable other ways of inserting the image into the prompt
-
-    token_ids = [image_token_id] * image_feature_size
-    token_ids += [0] * (seq_len - image_feature_size)
-
-    return SequenceData(token_ids)
-
-
 def _get_clip_num_patches(hf_config: CLIPVisionConfig) -> int:
     image_size = hf_config.image_size
     patch_size = hf_config.patch_size
@@ -75,25 +59,28 @@ def _get_llava_next_num_unpadded_features(
     return (unpadded_features, newline_features)


-def _get_llava_next_image_feature_size(hf_config: LlavaNextConfig) -> int:
+def _get_llava_next_image_feature_size(
+    hf_config: LlavaNextConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
     vision_config = hf_config.vision_config

     if isinstance(vision_config, CLIPVisionConfig):
         num_patches = _get_clip_num_patches(vision_config)
         base_feature_size = num_patches * num_patches

-        # Results in the max possible feature size
-        dummy_height, dummy_width = 448, 448
         num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-            image_size=(dummy_height, dummy_width),
+            image_size=(input_height, input_width),
             grid_pinpoints=hf_config.image_grid_pinpoints,
             patch_size=vision_config.image_size,
         )

         (
             unpadded_feature_size,
             newline_feature_size,
-        ) = _get_llava_next_num_unpadded_features(dummy_height, dummy_width,
+        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                                   num_patches,
                                                   num_patch_height,
                                                   num_patch_width)
@@ -108,10 +95,8 @@ class DummyImageDataFactories:
"""Contains factories for dummy image data factories."""

@classmethod
def _dummy_data_for_clip(
def _dummy_seq_data_for_clip(
cls,
model_config: ModelConfig,
multimodal_config: VisionLanguageConfig,
hf_config: CLIPVisionConfig,
seq_len: int,
*,
@@ -123,26 +108,42 @@ def _dummy_data_for_clip(
         else:
             image_feature_size = image_feature_size_override

-        seq_data = _get_dummy_seq_data(
-            seq_len=seq_len,
-            image_token_id=image_token_id,
-            image_feature_size=image_feature_size,
-        )
+        token_ids = [image_token_id] * image_feature_size
+        token_ids += [0] * (seq_len - image_feature_size)
+        return SequenceData(token_ids)

-        image_input_type = multimodal_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-        multi_modal_data: MultiModalData
-        if image_input_type == ImageInputType.PIXEL_VALUES:
-            width = height = hf_config.image_size
-            image = Image.new("RGB", (width, height), color=0)
-            multi_modal_data = ImagePixelData(image)
-        elif image_input_type == ImageInputType.IMAGE_FEATURES:
-            depth = hf_config.hidden_size
-            values = torch.zeros((1, image_feature_size, depth),
-                                 dtype=torch.float16)
-            multi_modal_data = ImageFeatureData(values)
-
-        return seq_data, multi_modal_data
+    @classmethod
+    def _dummy_pixel_data_for_clip(
+        cls,
+        hf_config: CLIPVisionConfig,
+        *,
+        image_width_override: Optional[int] = None,
+        image_height_override: Optional[int] = None,
+    ):
+        width = height = hf_config.image_size
+        if image_width_override is not None:
+            width = image_width_override
+        if image_height_override is not None:
+            height = image_height_override
+
+        image = Image.new("RGB", (width, height), color=0)
+        return ImagePixelData(image)
+
+    @classmethod
+    def _dummy_feature_data_for_clip(
+        cls,
+        hf_config: CLIPVisionConfig,
+        *,
+        image_feature_size_override: Optional[int] = None,
+    ):
+        if image_feature_size_override is None:
+            image_feature_size = _get_clip_image_feature_size(hf_config)
+        else:
+            image_feature_size = image_feature_size_override
+
+        values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
+                             dtype=torch.float16)
+        return ImageFeatureData(values)

     @classmethod
     def _dummy_data_for_llava(
@@ -155,14 +156,24 @@ def _dummy_data_for_llava(
         vision_config = hf_config.vision_config

         if isinstance(vision_config, CLIPVisionConfig):
-            return cls._dummy_data_for_clip(
-                model_config,
-                multimodal_config,
+            seq_data = cls._dummy_seq_data_for_clip(
                 vision_config,
-                seq_len=seq_len,
+                seq_len,
                 image_token_id=hf_config.image_token_index,
             )
+
+            image_input_type = multimodal_config.image_input_type
+            ImageInputType = VisionLanguageConfig.ImageInputType
+            multi_modal_data: MultiModalData
+            if image_input_type == ImageInputType.PIXEL_VALUES:
+                multi_modal_data = cls._dummy_pixel_data_for_clip(
+                    vision_config)
+            elif image_input_type == ImageInputType.IMAGE_FEATURES:
+                multi_modal_data = cls._dummy_feature_data_for_clip(
+                    vision_config)
+
+            return seq_data, multi_modal_data

         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)

@@ -175,18 +186,37 @@ def _dummy_data_for_llava_next(
         seq_len: int,
     ):
         vision_config = hf_config.vision_config
-        image_feature_size = _get_llava_next_image_feature_size(hf_config)
+
+        # Result in the max possible feature size
+        dummy_height = dummy_width = 448
+        image_feature_size = _get_llava_next_image_feature_size(
+            hf_config, input_height=dummy_height, input_width=dummy_width)

         if isinstance(vision_config, CLIPVisionConfig):
-            return cls._dummy_data_for_clip(
-                model_config,
-                multimodal_config,
+            seq_data = cls._dummy_seq_data_for_clip(
                 vision_config,
-                seq_len=seq_len,
+                seq_len,
                 image_token_id=hf_config.image_token_index,
                 image_feature_size_override=image_feature_size,
             )
+
+            image_input_type = multimodal_config.image_input_type
+            ImageInputType = VisionLanguageConfig.ImageInputType
+            multi_modal_data: MultiModalData
+            if image_input_type == ImageInputType.PIXEL_VALUES:
+                multi_modal_data = cls._dummy_pixel_data_for_clip(
+                    vision_config,
+                    image_width_override=dummy_width,
+                    image_height_override=dummy_height,
+                )
+            elif image_input_type == ImageInputType.IMAGE_FEATURES:
+                multi_modal_data = cls._dummy_feature_data_for_clip(
+                    vision_config,
+                    image_feature_size_override=image_feature_size,
+                )
+
+            return seq_data, multi_modal_data

         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)

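The sizing arithmetic behind these dummy factories is easier to follow spelled out. Below is a minimal standalone sketch, assuming the CLIP ViT-L/14 tower at 336 px used by the llava-hf checkpoints; the config values are illustrative assumptions, not taken from this diff:

    # Base feature size for one CLIP tile (assumed config values).
    image_size = 336  # CLIPVisionConfig.image_size
    patch_size = 14   # CLIPVisionConfig.patch_size

    num_patches = image_size // patch_size         # 24 patches per side
    base_feature_size = num_patches * num_patches  # 576 tokens per tile

For LLaVA-NeXT, the total adds the unpadded anyres features plus one newline token per feature row, both of which depend on the input resolution. That dependency is the point of the fix: input_height and input_width are now threaded into _get_llava_next_image_feature_size instead of a 448x448 size being hard-coded inside it.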

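For the anyres part, here is a rough sketch of the quantity that _get_llava_next_num_unpadded_features (whose body lies outside this diff) is responsible for, modeled on the unpad_image logic in Hugging Face transformers' LLaVA-NeXT code. The function name and exact rounding are illustrative, not the vLLM source:

    def num_unpadded_features(height: int, width: int, npatches: int,
                              num_patch_height: int, num_patch_width: int):
        # Feature-map size if the scaled image filled the whole tile grid.
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        # Undo the letterbox padding, preserving the input aspect ratio.
        if width / height > current_width / current_height:
            current_height = (height * current_width) // width
        else:
            current_width = (width * current_height) // height

        unpadded_features = current_height * current_width
        newline_features = current_height  # one newline token per row
        return unpadded_features, newline_features

Assuming a 24-patch side and a 2x2 tile grid (the anyres fit typically selected for a square 448x448 input), num_unpadded_features(448, 448, 24, 2, 2) returns (2304, 48); adding the 576 base tokens gives the 2928-token worst case that the 448x448 dummy image is meant to exercise.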