From 251752cc587309da7d38e6cca342852a57642cde Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 22 May 2024 23:34:09 +0800 Subject: [PATCH 01/30] init phi3v support --- examples/phi3v_example.py | 40 +++ vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/phi3v.py | 413 +++++++++++++++++++++++++ 3 files changed, 454 insertions(+) create mode 100644 examples/phi3v_example.py create mode 100644 vllm/model_executor/models/phi3v.py diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py new file mode 100644 index 000000000000..219366627ac9 --- /dev/null +++ b/examples/phi3v_example.py @@ -0,0 +1,40 @@ +from PIL import Image +import requests +from transformers import AutoProcessor + +from vllm import LLM +from vllm.sequence import MultiModalData + + +def run_phi3v(): + model_path = "/data/LLM-model/Phi-3-vision-128k-instruct" + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + llm = LLM( + model=model_path, + trust_remote_code=True, + image_input_type="pixel_values", + image_token_id=-1, + image_input_shape="1,3,336,336", + image_feature_size=1024, + ) + + url = "https://www.ilankelman.org/stopsigns/australia.jpg" + image = Image.open(requests.get(url, stream=True).raw) + user_prompt = '<|user|>\n' + assistant_prompt = '<|assistant|>\n' + prompt_suffix = "<|end|>\n" + + # single-image prompt + prompt = f"{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}" + inputs = processor(prompt, image, return_tensors="pt") + + outputs = llm.generate(prompt_token_ids=inputs["input_ids"][0], + multi_modal_data=MultiModalData( + type=MultiModalData.Type.IMAGE, data=inputs["pixel_values"])) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +if __name__ == "__main__": + run_phi3v() diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 6aec104be8da..339ceb95a05e 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -47,6 +47,7 @@ "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py new file mode 100644 index 000000000000..c179d9427bf8 --- /dev/null +++ b/vllm/model_executor/models/phi3v.py @@ -0,0 +1,413 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union + +import math +import torch +import torch.nn as nn +from transformers import CLIPVisionModel, LlamaConfig, PretrainedConfig +from transformers import CLIPVisionConfig +from transformers.utils import logging +from datetime import datetime + +from transformers import CLIPVisionModel + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, VisionLanguageConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.vlm_base import VisionLanguageModelBase +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + +logger = logging.get_logger(__name__) + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( + attention_dropout=0.0, + dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + initializer_factor=1.0, + initializer_range=0.02, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768 +) + + +# copy from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py +class Phi3ImageEmbedding(nn.Module): + """Phi3 Image embedding.""" + + def __init__(self, config: PretrainedConfig, wte=None) -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size + if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): + embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + self.wte = wte + + if isinstance(config.img_processor, dict) and config.img_processor.get('name', None) == 'clip_vision_model': + assert 'model_name' in config.img_processor, 'model_name must be provided for CLIPVisionModel' + assert 'image_dim_out' in config.img_processor, 'image_dim_out must be provided for CLIPVisionModel' + assert 'num_img_tokens' in config.img_processor, 'num_img_tokens must be provided for CLIPVisionModel' + assert config.img_processor['model_name'] == 'openai/clip-vit-large-patch14-336' + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config) + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + else: + raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented') + + self.image_dim_out = image_dim_out + self.img_sizes = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = config.embd_layer.get('use_hd_transform', False) + self.with_learnable_separator = config.embd_layer.get('with_learnable_separator', False) + self.hd_transform_order = config.embd_layer.get('hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and 
with_learnable_separator should have same value' + if self.with_learnable_separator: + assert self.use_hd_transform, 'learnable separator is only for hd transform' + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * 4])) + logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}') + + projection_cls = config.embd_layer.get('projection_cls', 'linear') + if projection_cls == 'linear': + self.img_projection = nn.Linear(image_dim_out, hidden_size) + elif projection_cls == 'mlp' and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * 4, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == 'mlp': + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented') + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get('type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + + def set_img_features(self, img_features: torch.FloatTensor) -> None: + self.img_features = img_features + + def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: + self.img_sizes = img_sizes + + def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature[:, 1:] + return patch_feature + + if TYPE_FEATURE == "cls_patch": + return img_feature + + raise NotImplementedError + + def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None, **kwargs) -> torch.FloatTensor: + + MAX_INPUT_ID = int(1e9) + img_embeds = pixel_values + img_sizes = image_sizes + + if self.img_features is not None: + img_embeds = self.img_features.clone() + self.img_features = None + + if self.img_sizes is not None: + img_sizes = self.img_sizes + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + with torch.no_grad(): + positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False) + + select = False + + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + if len(positions.tolist()) > 0: + with torch.no_grad(): + g_values = abs(input_ids[positions[:, 0], positions[:, 1]]) + + if self.use_hd_transform and img_sizes is not None and len(img_sizes): + hd_transform = True + assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' + # img_embeds: 
(num_images, max_num_crops, 3, H, W) + # img_sizes: (num_images, 2).view(1, -1) + + start_time = datetime.now() + bs = img_embeds.shape[0] + # Nx(HW)xC + img_features = self.get_img_features(img_embeds.flatten(0, 1)) + base_feat_height = base_feat_width = int(img_features.shape[1] ** 0.5) + + assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + h = h // 336 + w = w // 336 + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature.reshape(1,H,H,C).reshape(1,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//2,H//2,4*C).contiguous() + temp_glb_GN = self.sub_GN.repeat(1, H//2, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,4*C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,4*C).contiguous() + sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(0,1,3,2,4,5).reshape(1,h*12,w*12,4*C) + temp_sub_GN = self.sub_GN.repeat(1, h*12, 1, 1) + sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,4*C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented') + + temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' + output_len.append(temp_len) + + num_img_tokens = output_len + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype)) + img_set_tensor.append(img_feature_proj) + logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}') + elif img_embeds.ndim == 4: + selected_g_values = g_values[::self.num_img_tokens] + assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}' + start_time = datetime.now() + tt = ( + self.get_img_features(img_embeds) + .to(target_device) + .to(target_dtype) + .reshape(-1, self.image_dim_out) + ) + logger.info(f'img_embeds size: {img_embeds.size()}, loading time {datetime.now() - start_time}') + img_set_tensor = self.img_projection(tt) # adapted visual features. 
+ elif img_embeds.ndim == 3: + selected_g_values = g_values[::self.num_img_tokens] + assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}' + tt = ( + img_embeds + .to(target_device) + .to(target_dtype) + .view(-1, self.image_dim_out) + ) + img_set_tensor = self.img_projection(tt) # adapted visual features. + else: + raise NotImplementedError + select = True + + with torch.no_grad(): + input_ids.clamp_min_(0).clamp_max_(self.vocab_size) + + hidden_states = self.wte(input_ids) + + if select: + if hd_transform: + idx = 0 + for i, cnt in enumerate(num_img_tokens): + hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ( + img_set_tensor[i] + .to(hidden_states.dtype) + .to(hidden_states.device) + ) + idx += cnt + else: + idx = 0 + assert len(selected_g_values) * self.num_img_tokens == len(img_set_tensor), f'len(selected_g_values) * self.num_img_tokens = {len(selected_g_values) * self.num_img_tokens}, len(img_set_tensor) = {len(img_set_tensor)}' + for i, g in enumerate(selected_g_values): + cnt = self.num_img_tokens + hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ( + img_set_tensor[i * cnt : (i + 1) * cnt] + .to(hidden_states.dtype) + .to(hidden_states.device) + ) + idx += cnt + + if self.drop is not None: + hidden_states = self.drop(hidden_states) + + return hidden_states + + +class Phi3VForCausalLM(VisionLanguageModelBase): + def __init__(self, + config: PretrainedConfig, + vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__(vision_language_config) + self.config = config + self.model = LlamaModel(config, cache_config, quant_config) + self.vision_embed_tokens = Phi3ImageEmbedding(config, self.model.embed_tokens) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + image_input: Optional[dict] = None): + + if image_input is not None: + inputs_embeds = self.vision_embed_tokens(input_ids, **image_input) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.model(input_ids, + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using 
ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if "model.vision_embed_tokens" in name: + name = name.replace("model.vision_embed_tokens", "vision_embed_tokens") + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "vision_embed_tokens" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file From 70e7017ff41d187887fadcc1f8c0f5b0ca6e4aee Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 23 May 2024 17:52:04 +0800 Subject: [PATCH 02/30] make phi3v work --- examples/phi3v_example.py | 13 +++++++++---- vllm/model_executor/models/phi3v.py | 15 ++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index 219366627ac9..a1d5e63fcb11 100644 --- a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -1,3 +1,5 @@ +import os + from PIL import Image import requests from transformers import AutoProcessor @@ -6,15 +8,18 @@ from vllm.sequence import MultiModalData +# os.environ["VLLM_CPU_KVCACHE_SPACE"] = "10" + def run_phi3v(): model_path = "/data/LLM-model/Phi-3-vision-128k-instruct" processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) llm = LLM( model=model_path, trust_remote_code=True, + max_model_len=4096, image_input_type="pixel_values", image_token_id=-1, - image_input_shape="1,3,336,336", + image_input_shape="1008, 1344", image_feature_size=1024, ) @@ -27,10 +32,10 @@ def run_phi3v(): # single-image prompt prompt = f"{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}" inputs = processor(prompt, image, return_tensors="pt") + multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=inputs["pixel_values"]) - outputs = llm.generate(prompt_token_ids=inputs["input_ids"][0], - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=inputs["pixel_values"])) + outputs = llm.generate(prompt_token_ids=inputs["input_ids"].tolist(), multi_modal_data=multi_modal_data) + # outputs = llm.generate(prompt) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index c179d9427bf8..9c7387989e2f 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -58,7 +58,7 @@ ) -# copy from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py +# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py class Phi3ImageEmbedding(nn.Module): """Phi3 Image embedding.""" @@ -158,7 +158,7 @@ def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: raise NotImplementedError - def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None, **kwargs) -> torch.FloatTensor: + def forward(self, input_ids: torch.LongTensor, 
pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor: MAX_INPUT_ID = int(1e9) img_embeds = pixel_values @@ -215,7 +215,7 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, if isinstance(img_sizes, torch.Tensor): img_sizes = img_sizes.view(-1, 2) for _bs in range(bs): - h, w = img_sizes[_bs] + h, w = img_sizes h = h // 336 w = w // 336 B_ = h * w @@ -314,10 +314,7 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, ) idx += cnt - if self.drop is not None: - hidden_states = self.drop(hidden_states) - - return hidden_states + return hidden_states.squeeze(0) class Phi3VForCausalLM(VisionLanguageModelBase): @@ -340,9 +337,9 @@ def forward(self, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, image_input: Optional[dict] = None): - if image_input is not None: - inputs_embeds = self.vision_embed_tokens(input_ids, **image_input) + print(image_input.shape) + inputs_embeds = self.vision_embed_tokens(input_ids, image_input, self.vision_language_config.image_input_shape) input_ids = None else: From 618a2cbc90126beedf6daf76a8890ab321f0afdd Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 23 May 2024 17:53:37 +0800 Subject: [PATCH 03/30] remove debug code --- examples/phi3v_example.py | 2 +- vllm/model_executor/models/phi3v.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index a1d5e63fcb11..4b9bdf7a837b 100644 --- a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -16,7 +16,7 @@ def run_phi3v(): llm = LLM( model=model_path, trust_remote_code=True, - max_model_len=4096, + max_model_len=8192, image_input_type="pixel_values", image_token_id=-1, image_input_shape="1008, 1344", diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 9c7387989e2f..3387d2f3174a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -338,7 +338,6 @@ def forward(self, attn_metadata: AttentionMetadata, image_input: Optional[dict] = None): if image_input is not None: - print(image_input.shape) inputs_embeds = self.vision_embed_tokens(input_ids, image_input, self.vision_language_config.image_input_shape) input_ids = None From ffb32fb52c4eda715d86aa63afc757f19315e801 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 24 May 2024 15:05:08 +0800 Subject: [PATCH 04/30] remove dropout from Phi3ImageEmbedding --- vllm/model_executor/models/phi3v.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 3387d2f3174a..212e594d468b 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -67,11 +67,6 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None: # n_embed or hidden_size hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size - if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): - embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop - self.drop = nn.Dropout(embd_drop) - else: - self.drop = None self.wte = wte @@ -99,8 +94,8 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None: if self.with_learnable_separator: assert self.use_hd_transform, 'learnable separator is only for hd transform' # 1024 * 4, merge spatial to channel dimension - self.glb_GN = nn.Parameter(torch.zeros([1, 1, 
self.image_dim_out * 4])) - self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * 4])) + self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4])) logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}') projection_cls = config.embd_layer.get('projection_cls', 'linear') @@ -211,9 +206,7 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, output_imgs = [] output_len = [] - # training is tensor, inference is list - if isinstance(img_sizes, torch.Tensor): - img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): h, w = img_sizes h = h // 336 From 76e6f8e3c71b545ec2fcab841dd7101b89506bbc Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 24 May 2024 17:25:09 +0800 Subject: [PATCH 05/30] clean code --- examples/phi3v_example.py | 23 +-- vllm/model_executor/models/phi3v.py | 254 ++++++++++++++-------------- 2 files changed, 143 insertions(+), 134 deletions(-) diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index 4b9bdf7a837b..fb2a119797cd 100644 --- a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -1,22 +1,20 @@ -import os - -from PIL import Image import requests +from PIL import Image from transformers import AutoProcessor from vllm import LLM from vllm.sequence import MultiModalData -# os.environ["VLLM_CPU_KVCACHE_SPACE"] = "10" - def run_phi3v(): model_path = "/data/LLM-model/Phi-3-vision-128k-instruct" - processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_path, + trust_remote_code=True) llm = LLM( model=model_path, trust_remote_code=True, - max_model_len=8192, + dtype='bfloat16', + max_model_len=4096, image_input_type="pixel_values", image_token_id=-1, image_input_shape="1008, 1344", @@ -27,14 +25,17 @@ def run_phi3v(): image = Image.open(requests.get(url, stream=True).raw) user_prompt = '<|user|>\n' assistant_prompt = '<|assistant|>\n' - prompt_suffix = "<|end|>\n" + suffix = "<|end|>\n" # single-image prompt - prompt = f"{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}" + prompt = "What is shown in this image?" + prompt = f"{user_prompt}<|image_1|>\n{prompt}{suffix}{assistant_prompt}" inputs = processor(prompt, image, return_tensors="pt") - multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=inputs["pixel_values"]) + multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, + data=inputs["pixel_values"]) - outputs = llm.generate(prompt_token_ids=inputs["input_ids"].tolist(), multi_modal_data=multi_modal_data) + outputs = llm.generate(prompt_token_ids=inputs["input_ids"].tolist(), + multi_modal_data=multi_modal_data) # outputs = llm.generate(prompt) for o in outputs: generated_text = o.outputs[0].text diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 212e594d468b..3d373f4a499a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -1,4 +1,5 @@ # coding=utf-8 +# Copyright 2024 The vLLM team. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,21 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import Iterable, List, Optional, Tuple -import math import torch import torch.nn as nn -from transformers import CLIPVisionModel, LlamaConfig, PretrainedConfig -from transformers import CLIPVisionConfig +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig from transformers.utils import logging -from datetime import datetime - -from transformers import CLIPVisionModel from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VisionLanguageConfig -from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -40,22 +35,16 @@ logger = logging.get_logger(__name__) -CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( - attention_dropout=0.0, - dropout=0.0, - hidden_act="quick_gelu", - hidden_size=1024, - image_size=336, - initializer_factor=1.0, - initializer_range=0.02, - intermediate_size=4096, - layer_norm_eps=1e-05, - num_attention_heads=16, - num_channels=3, - num_hidden_layers=24, - patch_size=14, - projection_dim=768 -) +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + intermediate_size=4096, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768) # adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py @@ -66,37 +55,42 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None: super().__init__() # n_embed or hidden_size - hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size self.wte = wte - if isinstance(config.img_processor, dict) and config.img_processor.get('name', None) == 'clip_vision_model': - assert 'model_name' in config.img_processor, 'model_name must be provided for CLIPVisionModel' - assert 'image_dim_out' in config.img_processor, 'image_dim_out must be provided for CLIPVisionModel' - assert 'num_img_tokens' in config.img_processor, 'num_img_tokens must be provided for CLIPVisionModel' - assert config.img_processor['model_name'] == 'openai/clip-vit-large-patch14-336' + if isinstance(config.img_processor, dict) and config.img_processor.get( + 'name', None) == 'clip_vision_model': clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG self.img_processor = CLIPVisionModel(clip_config) image_dim_out = config.img_processor['image_dim_out'] self.num_img_tokens = config.img_processor['num_img_tokens'] else: - raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented') + raise NotImplementedError( + f'img_processor = {config.img_processor}, not implemented') self.image_dim_out = image_dim_out self.img_sizes = None # global_gn and sub_gn for hd transform, serves as line separator - self.use_hd_transform = config.embd_layer.get('use_hd_transform', False) - self.with_learnable_separator = config.embd_layer.get('with_learnable_separator', False) - self.hd_transform_order = config.embd_layer.get('hd_transform_order', 'glb_sub') + self.use_hd_transform = config.embd_layer.get('use_hd_transform', + False) + self.with_learnable_separator = config.embd_layer.get( + 'with_learnable_separator', False) + self.hd_transform_order = config.embd_layer.get( + 
'hd_transform_order', 'glb_sub') # with_hd_transform and with_learnable_separator should have same value - assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value' + assert self.use_hd_transform == self.with_learnable_separator if self.with_learnable_separator: - assert self.use_hd_transform, 'learnable separator is only for hd transform' # 1024 * 4, merge spatial to channel dimension - self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) - self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4])) - logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}') + self.glb_GN = nn.Parameter( + torch.empty([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter( + torch.empty([1, 1, 1, self.image_dim_out * 4])) + logger.info( + 'learnable separator enabled for hd transform' + 'hd_transform_order = %s', self.hd_transform_order) projection_cls = config.embd_layer.get('projection_cls', 'linear') if projection_cls == 'linear': @@ -106,42 +100,47 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None: depth = 2 layers = [nn.Linear(image_dim_out * 4, dim_projection)] for _ in range(1, depth): - layers.extend([nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) elif projection_cls == 'mlp': dim_projection = hidden_size depth = 2 layers = [nn.Linear(image_dim_out, dim_projection)] for _ in range(1, depth): - layers.extend([nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) else: - raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented') + raise NotImplementedError( + f'projection_cls = {projection_cls}, not implemented') self.vocab_size = config.vocab_size self.img_features = None if isinstance(config.img_processor, dict): self.layer_idx = config.img_processor.get('layer_idx', -2) - self.type_feature = config.img_processor.get('type_feature', 'patch') + self.type_feature = config.img_processor.get( + 'type_feature', 'patch') else: self.layer_idx = -2 self.type_feature = 'patch' - def set_img_features(self, img_features: torch.FloatTensor) -> None: self.img_features = img_features def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: self.img_sizes = img_sizes - def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: + def get_img_features(self, + img_embeds: torch.FloatTensor) -> torch.FloatTensor: LAYER_IDX = self.layer_idx TYPE_FEATURE = self.type_feature - img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_processor_output = self.img_processor(img_embeds, + output_hidden_states=True) img_feature = img_processor_output.hidden_states[LAYER_IDX] if TYPE_FEATURE == "patch": @@ -153,7 +152,10 @@ def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: raise NotImplementedError - def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor: + def forward(self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + image_sizes=None) -> torch.FloatTensor: MAX_INPUT_ID = int(1e9) img_embeds = pixel_values @@ -170,37 +172,38 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: 
torch.FloatTensor, input_ids = input_ids.view(-1, input_shape[-1]) with torch.no_grad(): - positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False) - + positions = torch.nonzero( + (input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False) + select = False - if isinstance(self.img_projection, nn.Sequential): - target_device = self.img_projection[0].bias.device - target_dtype = self.img_projection[0].bias.dtype - else: # It's a single nn.Linear layer - target_device = self.img_projection.bias.device - target_dtype = self.img_projection.bias.dtype + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype if len(positions.tolist()) > 0: with torch.no_grad(): g_values = abs(input_ids[positions[:, 0], positions[:, 1]]) - if self.use_hd_transform and img_sizes is not None and len(img_sizes): + if self.use_hd_transform and img_sizes is not None and len( + img_sizes): hd_transform = True - assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # img_embeds: (num_images, max_num_crops, 3, H, W) # img_sizes: (num_images, 2).view(1, -1) - start_time = datetime.now() bs = img_embeds.shape[0] # Nx(HW)xC img_features = self.get_img_features(img_embeds.flatten(0, 1)) - base_feat_height = base_feat_width = int(img_features.shape[1] ** 0.5) - - assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' + base_feat_height = base_feat_width = int( + img_features.shape[1]**0.5) # bs x max_num_crops x (24x24) x C - img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + img_features = img_features.view( + bs, -1, base_feat_height * base_feat_width, + self.image_dim_out) C = self.image_dim_out H = base_feat_height @@ -209,7 +212,7 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, for _bs in range(bs): h, w = img_sizes - h = h // 336 + h = h // 336 w = w // 336 B_ = h * w @@ -217,11 +220,15 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, global_img_feature = img_features[_bs, :1] # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1,H,H,C).reshape(1,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//2,H//2,4*C).contiguous() - temp_glb_GN = self.sub_GN.repeat(1, H//2, 1, 1) + glb_img = global_img_feature.reshape(1, H, H, C).reshape( + 1, H // 2, 2, H // 2, 2, + C).permute(0, 1, 3, 2, 4, + 5).reshape(1, H // 2, H // 2, 4 * C) + temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) # 1 x 156 x 4096 - glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,4*C) + glb_img = torch.cat([glb_img, temp_glb_GN], + dim=2).reshape(1, -1, 4 * C) # (max_num_crops-1) x (12x12) x C sub_img = img_features[_bs, 1:] @@ -229,89 +236,86 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, # get rid of padding sub_img sub_img = sub_img[:B_] - # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) - sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,4*C).contiguous() - sub_img = sub_img.reshape(1, h, w, 12, 12, 
-1).permute(0,1,3,2,4,5).reshape(1,h*12,w*12,4*C) - temp_sub_GN = self.sub_GN.repeat(1, h*12, 1, 1) - sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,4*C) + sub_img = sub_img.reshape(B_, H, H, C).reshape( + B_, H // 2, 2, H // 2, 2, + C).permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C) + sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( + 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) + temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) + sub_img = torch.cat([sub_img, temp_sub_GN], + dim=2).reshape(1, -1, 4 * C) # (1, num_img_tokens, 1024*4) # glb + sub if self.hd_transform_order == 'glb_sub': - output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) elif self.hd_transform_order == 'sub_glb': - output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) else: - raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented') + raise NotImplementedError( + f'hd_transform_order = {self.hd_transform_order},' + 'not implemented') - temp_len = int((h*w+1)*144 + 1 + (h+1)*12) - assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' + temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) output_len.append(temp_len) - + num_img_tokens = output_len img_set_tensor = [] for _output_img in output_imgs: - img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype)) + img_feature_proj = self.img_projection( + _output_img.to(target_device).to(target_dtype)) img_set_tensor.append(img_feature_proj) - logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}') elif img_embeds.ndim == 4: selected_g_values = g_values[::self.num_img_tokens] - assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}' - start_time = datetime.now() - tt = ( - self.get_img_features(img_embeds) - .to(target_device) - .to(target_dtype) - .reshape(-1, self.image_dim_out) - ) - logger.info(f'img_embeds size: {img_embeds.size()}, loading time {datetime.now() - start_time}') - img_set_tensor = self.img_projection(tt) # adapted visual features. + tt = (self.get_img_features(img_embeds).to(target_device).to( + target_dtype).reshape(-1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. elif img_embeds.ndim == 3: selected_g_values = g_values[::self.num_img_tokens] - assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}' - tt = ( - img_embeds - .to(target_device) - .to(target_dtype) - .view(-1, self.image_dim_out) - ) - img_set_tensor = self.img_projection(tt) # adapted visual features. + tt = (img_embeds.to(target_device).to(target_dtype).view( + -1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. 
else: raise NotImplementedError select = True - + with torch.no_grad(): input_ids.clamp_min_(0).clamp_max_(self.vocab_size) - + hidden_states = self.wte(input_ids) if select: if hd_transform: idx = 0 for i, cnt in enumerate(num_img_tokens): - hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ( - img_set_tensor[i] - .to(hidden_states.dtype) - .to(hidden_states.device) - ) + hidden_states[positions[idx, 0], + positions[idx, 1]:positions[idx, 1] + + cnt] = (img_set_tensor[i].to( + hidden_states.dtype).to( + hidden_states.device)) idx += cnt else: idx = 0 - assert len(selected_g_values) * self.num_img_tokens == len(img_set_tensor), f'len(selected_g_values) * self.num_img_tokens = {len(selected_g_values) * self.num_img_tokens}, len(img_set_tensor) = {len(img_set_tensor)}' for i, g in enumerate(selected_g_values): cnt = self.num_img_tokens - hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ( - img_set_tensor[i * cnt : (i + 1) * cnt] - .to(hidden_states.dtype) - .to(hidden_states.device) - ) + hidden_states[positions[idx, 0], + positions[idx, 1]:positions[idx, 1] + + cnt] = ( + img_set_tensor[i * cnt:(i + 1) * cnt].to( + hidden_states.dtype).to( + hidden_states.device)) idx += cnt return hidden_states.squeeze(0) class Phi3VForCausalLM(VisionLanguageModelBase): - def __init__(self, + + def __init__(self, config: PretrainedConfig, vision_language_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, @@ -319,11 +323,12 @@ def __init__(self, super().__init__(vision_language_config) self.config = config self.model = LlamaModel(config, cache_config, quant_config) - self.vision_embed_tokens = Phi3ImageEmbedding(config, self.model.embed_tokens) + self.vision_embed_tokens = Phi3ImageEmbedding(config, + self.model.embed_tokens) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -331,17 +336,19 @@ def forward(self, attn_metadata: AttentionMetadata, image_input: Optional[dict] = None): if image_input is not None: - inputs_embeds = self.vision_embed_tokens(input_ids, image_input, self.vision_language_config.image_input_shape) + inputs_embeds = self.vision_embed_tokens( + input_ids, image_input, + self.vision_language_config.image_input_shape) input_ids = None else: inputs_embeds = None hidden_states = self.model(input_ids, - positions, - kv_caches, - attn_metadata, - inputs_embeds=inputs_embeds) + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds) return hidden_states @@ -358,7 +365,7 @@ def sample( ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -378,7 +385,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # the checkpoint. Skip them. 
continue if "model.vision_embed_tokens" in name: - name = name.replace("model.vision_embed_tokens", "vision_embed_tokens") + name = name.replace("model.vision_embed_tokens", + "vision_embed_tokens") for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue @@ -399,4 +407,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader(param, loaded_weight) From 58330bafcd82227938e212cb6efbb27507d2a6b2 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 24 May 2024 23:05:38 +0800 Subject: [PATCH 06/30] optimize code structure --- vllm/model_executor/models/phi3v.py | 59 +++++++++++++---------------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 3d373f4a499a..4b22a3e44c00 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -188,8 +188,7 @@ def forward(self, with torch.no_grad(): g_values = abs(input_ids[positions[:, 0], positions[:, 1]]) - if self.use_hd_transform and img_sizes is not None and len( - img_sizes): + if self.use_hd_transform and img_sizes: hd_transform = True # img_embeds: (num_images, max_num_crops, 3, H, W) # img_sizes: (num_images, 2).view(1, -1) @@ -220,10 +219,10 @@ def forward(self, global_img_feature = img_features[_bs, :1] # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1, H, H, C).reshape( - 1, H // 2, 2, H // 2, 2, - C).permute(0, 1, 3, 2, 4, - 5).reshape(1, H // 2, H // 2, 4 * C) + glb_img = global_img_feature \ + .reshape(1, H // 2, 2, H // 2, 2,C) \ + .permute(0, 1, 3, 2, 4, 5) \ + .reshape(1, H // 2, H // 2, 4 * C) temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) # 1 x 156 x 4096 @@ -236,11 +235,11 @@ def forward(self, # get rid of padding sub_img sub_img = sub_img[:B_] - sub_img = sub_img.reshape(B_, H, H, C).reshape( - B_, H // 2, 2, H // 2, 2, - C).permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C) - sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( - 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) + sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \ + .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C) + sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \ + .permute(0, 1, 3, 2, 4, 5) \ + .reshape(1, h * 12, w * 12, 4 * C) temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1, -1, 4 * C) @@ -288,27 +287,23 @@ def forward(self, hidden_states = self.wte(input_ids) - if select: - if hd_transform: - idx = 0 - for i, cnt in enumerate(num_img_tokens): - hidden_states[positions[idx, 0], - positions[idx, 1]:positions[idx, 1] + - cnt] = (img_set_tensor[i].to( - hidden_states.dtype).to( - hidden_states.device)) - idx += cnt - else: - idx = 0 - for i, g in enumerate(selected_g_values): - cnt = self.num_img_tokens - hidden_states[positions[idx, 0], - positions[idx, 1]:positions[idx, 1] + - cnt] = ( - img_set_tensor[i * cnt:(i + 1) * cnt].to( - hidden_states.dtype).to( - hidden_states.device)) - idx += cnt + if select and hd_transform: + idx = 0 + for i, cnt in enumerate(num_img_tokens): + hidden_states[positions[idx, 0], + positions[idx, 1]:positions[idx, 1] + + cnt] = (img_set_tensor[i].to( + hidden_states.device, hidden_states.dtype)) + idx += cnt + elif select: + idx = 0 + for i, g in 
enumerate(selected_g_values): + cnt = self.num_img_tokens + hidden_states[positions[idx, 0], + positions[idx, 1]:positions[idx, 1] + + cnt] = (img_set_tensor[i * cnt:(i + 1) * cnt].to( + hidden_states.device, hidden_states.dtype)) + idx += cnt return hidden_states.squeeze(0) From b61c2be3fffb71ca43d0922b7f9b25426125d376 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 25 May 2024 00:56:52 +0800 Subject: [PATCH 07/30] Add Phi3VImagePixelInputs --- vllm/model_executor/models/phi3v.py | 37 ++++++++++++++++------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4b22a3e44c00..7346bb251544 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import torch import torch.nn as nn @@ -35,6 +35,10 @@ logger = logging.get_logger(__name__) +_KEYS_TO_MODIFY_MAPPING = { + "model.vision_embed_tokens": "vision_embed_tokens", +} + CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, hidden_act="quick_gelu", hidden_size=1024, @@ -308,6 +312,12 @@ def forward(self, return hidden_states.squeeze(0) +class Phi3VImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channels, height, width)""" + + class Phi3VForCausalLM(VisionLanguageModelBase): def __init__(self, @@ -329,7 +339,7 @@ def forward(self, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - image_input: Optional[dict] = None): + image_input: Phi3VImagePixelInputs = None): if image_input is not None: inputs_embeds = self.vision_embed_tokens( input_ids, image_input, @@ -374,24 +384,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "model.vision_embed_tokens" in name: - name = name.replace("model.vision_embed_tokens", - "vision_embed_tokens") + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: + # We only do sharding for language model + # and not vision model for now. + if "vision_embed_tokens" in name and self.vision_embed_tokens: continue - if "vision_embed_tokens" in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: + if weight_name not in name: continue - param = params_dict[name] + param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break From 1e3e18c1e00a4c6c5d4d0c9fad109129f37b2229 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 26 May 2024 14:05:55 +0800 Subject: [PATCH 08/30] refactor image embedding --- vllm/model_executor/models/phi3v.py | 333 ++++++++++++---------------- 1 file changed, 138 insertions(+), 195 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 7346bb251544..d7ebe7106977 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -51,86 +51,14 @@ projection_dim=768) -# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py -class Phi3ImageEmbedding(nn.Module): - """Phi3 Image embedding.""" +class Phi3ImageEmbeddingBase(nn.Module): - def __init__(self, config: PretrainedConfig, wte=None) -> None: + def __init__(self, wte=None) -> None: super().__init__() - - # n_embed or hidden_size - hidden_size = config.n_embd if hasattr( - config, 'n_embd') else config.hidden_size - self.wte = wte - - if isinstance(config.img_processor, dict) and config.img_processor.get( - 'name', None) == 'clip_vision_model': - clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - self.img_processor = CLIPVisionModel(clip_config) - image_dim_out = config.img_processor['image_dim_out'] - self.num_img_tokens = config.img_processor['num_img_tokens'] - else: - raise NotImplementedError( - f'img_processor = {config.img_processor}, not implemented') - - self.image_dim_out = image_dim_out - self.img_sizes = None - - # global_gn and sub_gn for hd transform, serves as line separator - self.use_hd_transform = config.embd_layer.get('use_hd_transform', - False) - self.with_learnable_separator = config.embd_layer.get( - 'with_learnable_separator', False) - self.hd_transform_order = config.embd_layer.get( - 'hd_transform_order', 'glb_sub') - # with_hd_transform and with_learnable_separator should have same value - assert self.use_hd_transform == self.with_learnable_separator - if self.with_learnable_separator: - # 1024 * 4, merge spatial to channel dimension - self.glb_GN = nn.Parameter( - torch.empty([1, 1, self.image_dim_out * 4])) - self.sub_GN = nn.Parameter( - torch.empty([1, 1, 1, self.image_dim_out * 4])) - logger.info( - 'learnable separator enabled for hd transform' - 'hd_transform_order = %s', self.hd_transform_order) - - projection_cls = config.embd_layer.get('projection_cls', 'linear') - if projection_cls == 'linear': - self.img_projection = nn.Linear(image_dim_out, hidden_size) - elif projection_cls == 'mlp' and self.use_hd_transform: - dim_projection = hidden_size - depth = 2 - layers = [nn.Linear(image_dim_out * 4, dim_projection)] - for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) - self.img_projection = nn.Sequential(*layers) - elif projection_cls == 'mlp': - dim_projection = hidden_size - depth = 2 - layers = [nn.Linear(image_dim_out, dim_projection)] - for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) - self.img_projection = nn.Sequential(*layers) - else: - raise NotImplementedError( - f'projection_cls = {projection_cls}, not implemented') - - self.vocab_size = config.vocab_size - 
self.img_features = None - - if isinstance(config.img_processor, dict): - self.layer_idx = config.img_processor.get('layer_idx', -2) - self.type_feature = config.img_processor.get( - 'type_feature', 'patch') - else: - self.layer_idx = -2 - self.type_feature = 'patch' + self.layer_idx: int + self.type_feature: str + self.img_processor: CLIPVisionModel def set_img_features(self, img_features: torch.FloatTensor) -> None: self.img_features = img_features @@ -156,6 +84,59 @@ def get_img_features(self, raise NotImplementedError + +# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py +class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): + """Phi3 Image embedding with HD transform.""" + + def __init__(self, config: PretrainedConfig, wte=None) -> None: + super().__init__(wte) + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config) + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + + self.image_dim_out = image_dim_out + self.img_sizes = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = config.embd_layer.get('use_hd_transform', + False) + self.with_learnable_separator = config.embd_layer.get( + 'with_learnable_separator', False) + self.hd_transform_order = config.embd_layer.get( + 'hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform & self.with_learnable_separator + + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter( + torch.empty([1, 1, 1, self.image_dim_out * 4])) + logger.info( + 'learnable separator enabled for hd transform' + 'hd_transform_order = %s', self.hd_transform_order) + + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * 4, dim_projection)] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + + self.vocab_size = config.vocab_size + self.img_features = None + + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get('type_feature', 'patch') + def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, @@ -175,123 +156,94 @@ def forward(self, input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - with torch.no_grad(): - positions = torch.nonzero( - (input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False) + positions = torch.nonzero( + (input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False) select = False - if isinstance(self.img_projection, nn.Sequential): - target_device = self.img_projection[0].bias.device - target_dtype = self.img_projection[0].bias.dtype - else: # It's a single nn.Linear layer - target_device = self.img_projection.bias.device - target_dtype = self.img_projection.bias.dtype + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype if len(positions.tolist()) > 0: - with torch.no_grad(): - g_values = abs(input_ids[positions[:, 0], positions[:, 1]]) - - if self.use_hd_transform and img_sizes: - hd_transform = True - # img_embeds: (num_images, max_num_crops, 3, H, W) - # 
img_sizes: (num_images, 2).view(1, -1) - - bs = img_embeds.shape[0] - # Nx(HW)xC - img_features = self.get_img_features(img_embeds.flatten(0, 1)) - base_feat_height = base_feat_width = int( - img_features.shape[1]**0.5) - - # bs x max_num_crops x (24x24) x C - img_features = img_features.view( - bs, -1, base_feat_height * base_feat_width, - self.image_dim_out) - C = self.image_dim_out - H = base_feat_height - - output_imgs = [] - output_len = [] - - for _bs in range(bs): - h, w = img_sizes - h = h // 336 - w = w // 336 - B_ = h * w - - # 1 x (24x24) x 1024 - global_img_feature = img_features[_bs, :1] - - # 1 x 12 x 12 x 4096 - glb_img = global_img_feature \ - .reshape(1, H // 2, 2, H // 2, 2,C) \ - .permute(0, 1, 3, 2, 4, 5) \ - .reshape(1, H // 2, H // 2, 4 * C) - temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) - - # 1 x 156 x 4096 - glb_img = torch.cat([glb_img, temp_glb_GN], - dim=2).reshape(1, -1, 4 * C) - - # (max_num_crops-1) x (12x12) x C - sub_img = img_features[_bs, 1:] - # 16x574x1024 - # get rid of padding sub_img - sub_img = sub_img[:B_] - - sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \ - .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C) - sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \ - .permute(0, 1, 3, 2, 4, 5) \ - .reshape(1, h * 12, w * 12, 4 * C) - temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) - sub_img = torch.cat([sub_img, temp_sub_GN], - dim=2).reshape(1, -1, 4 * C) - # (1, num_img_tokens, 1024*4) - - # glb + sub - if self.hd_transform_order == 'glb_sub': - output_imgs.append( - torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) - elif self.hd_transform_order == 'sub_glb': - output_imgs.append( - torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) - else: - raise NotImplementedError( - f'hd_transform_order = {self.hd_transform_order},' - 'not implemented') - - temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) - output_len.append(temp_len) - - num_img_tokens = output_len - img_set_tensor = [] - for _output_img in output_imgs: - img_feature_proj = self.img_projection( - _output_img.to(target_device).to(target_dtype)) - img_set_tensor.append(img_feature_proj) - elif img_embeds.ndim == 4: - selected_g_values = g_values[::self.num_img_tokens] - tt = (self.get_img_features(img_embeds).to(target_device).to( - target_dtype).reshape(-1, self.image_dim_out)) - img_set_tensor = self.img_projection( - tt) # adapted visual features. - elif img_embeds.ndim == 3: - selected_g_values = g_values[::self.num_img_tokens] - tt = (img_embeds.to(target_device).to(target_dtype).view( - -1, self.image_dim_out)) - img_set_tensor = self.img_projection( - tt) # adapted visual features. 
- else: - raise NotImplementedError + # if self.use_hd_transform and img_sizes: + # img_embeds: (num_images, max_num_crops, 3, H, W) + # img_sizes: (num_images, 2).view(1, -1) + + bs = img_embeds.shape[0] + # Nx(HW)xC + img_features = self.get_img_features(img_embeds.flatten(0, 1)) + base_feat_height = base_feat_width = int( + img_features.shape[1]**0.5) + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view( + bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + + for _bs in range(bs): + h, w = img_sizes + h = h // 336 + w = w // 336 + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature \ + .reshape(1, H // 2, 2, H // 2, 2,C) \ + .permute(0, 1, 3, 2, 4, 5) \ + .reshape(1, H // 2, H // 2, 4 * C) + temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], + dim=2).reshape(1, -1, 4 * C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \ + .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C) + sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \ + .permute(0, 1, 3, 2, 4, 5) \ + .reshape(1, h * 12, w * 12, 4 * C) + temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) + sub_img = torch.cat([sub_img, temp_sub_GN], + dim=2).reshape(1, -1, 4 * C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + + temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + output_len.append(temp_len) + + num_img_tokens = output_len + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection( + _output_img.to(target_device, target_dtype)) + img_set_tensor.append(img_feature_proj) select = True - with torch.no_grad(): - input_ids.clamp_min_(0).clamp_max_(self.vocab_size) + input_ids.clamp_min_(0).clamp_max_(self.vocab_size) hidden_states = self.wte(input_ids) - if select and hd_transform: + if select: idx = 0 for i, cnt in enumerate(num_img_tokens): hidden_states[positions[idx, 0], @@ -299,15 +251,6 @@ def forward(self, cnt] = (img_set_tensor[i].to( hidden_states.device, hidden_states.dtype)) idx += cnt - elif select: - idx = 0 - for i, g in enumerate(selected_g_values): - cnt = self.num_img_tokens - hidden_states[positions[idx, 0], - positions[idx, 1]:positions[idx, 1] + - cnt] = (img_set_tensor[i * cnt:(i + 1) * cnt].to( - hidden_states.device, hidden_states.dtype)) - idx += cnt return hidden_states.squeeze(0) @@ -328,8 +271,8 @@ def __init__(self, super().__init__(vision_language_config) self.config = config self.model = LlamaModel(config, cache_config, quant_config) - self.vision_embed_tokens = Phi3ImageEmbedding(config, - self.model.embed_tokens) + self.vision_embed_tokens = Phi3HDImageEmbedding( + config, self.model.embed_tokens) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() From 591a9b89d155679ea3bac0f67cd8a932d1e478be Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 26 May 2024 14:46:28 +0800 Subject: [PATCH 09/30] 
format phi3v_example --- examples/phi3v_example.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index fb2a119797cd..678b517d42a1 100644 --- a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -1,28 +1,27 @@ -import requests +import os +import subprocess + from PIL import Image from transformers import AutoProcessor -from vllm import LLM +from vllm import LLM, SamplingParams from vllm.sequence import MultiModalData def run_phi3v(): - model_path = "/data/LLM-model/Phi-3-vision-128k-instruct" + model_path = "microsoft/Phi-3-vision-128k-instruct" processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) llm = LLM( model=model_path, trust_remote_code=True, - dtype='bfloat16', - max_model_len=4096, image_input_type="pixel_values", image_token_id=-1, image_input_shape="1008, 1344", image_feature_size=1024, ) - url = "https://www.ilankelman.org/stopsigns/australia.jpg" - image = Image.open(requests.get(url, stream=True).raw) + image = Image.open("images/stop_sign.jpg") user_prompt = '<|user|>\n' assistant_prompt = '<|assistant|>\n' suffix = "<|end|>\n" @@ -33,14 +32,29 @@ def run_phi3v(): inputs = processor(prompt, image, return_tensors="pt") multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=inputs["pixel_values"]) - + sampling_params = SamplingParams(temperature=0, max_tokens=64) outputs = llm.generate(prompt_token_ids=inputs["input_ids"].tolist(), + sampling_params=sampling_params, multi_modal_data=multi_modal_data) - # outputs = llm.generate(prompt) for o in outputs: generated_text = o.outputs[0].text print(generated_text) if __name__ == "__main__": + s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/" + local_directory = "images" + + # Make sure the local directory exists or create it + os.makedirs(local_directory, exist_ok=True) + + # Use AWS CLI to sync the directory, assume anonymous access + subprocess.check_call([ + "aws", + "s3", + "sync", + s3_bucket_path, + local_directory, + "--no-sign-request", + ]) run_phi3v() From 2e5dc278ac74de208619edcb6e5a4da3073beb1b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 8 Jun 2024 11:39:42 +0800 Subject: [PATCH 10/30] refactor phi3v --- examples/phi3v_example.py | 22 +++++++++++++--------- vllm/model_executor/models/phi3v.py | 12 ++++++++++-- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index 678b517d42a1..bb169c1298e4 100644 --- a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -5,7 +5,7 @@ from transformers import AutoProcessor from vllm import LLM, SamplingParams -from vllm.sequence import MultiModalData +from vllm.multimodal.image import ImagePixelData def run_phi3v(): @@ -17,25 +17,29 @@ def run_phi3v(): trust_remote_code=True, image_input_type="pixel_values", image_token_id=-1, - image_input_shape="1008, 1344", + image_input_shape="1,3,1008,1344", image_feature_size=1024, + disable_image_processor=False, ) image = Image.open("images/stop_sign.jpg") - user_prompt = '<|user|>\n' - assistant_prompt = '<|assistant|>\n' + user_prompt = "<|user|>\n" + assistant_prompt = "<|assistant|>\n" suffix = "<|end|>\n" # single-image prompt prompt = "What is shown in this image?" 
prompt = f"{user_prompt}<|image_1|>\n{prompt}{suffix}{assistant_prompt}" + inputs = processor(prompt, image, return_tensors="pt") - multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, - data=inputs["pixel_values"]) + input_ids = inputs["input_ids"].squeeze(0).tolist() sampling_params = SamplingParams(temperature=0, max_tokens=64) - outputs = llm.generate(prompt_token_ids=inputs["input_ids"].tolist(), - sampling_params=sampling_params, - multi_modal_data=multi_modal_data) + + outputs = llm.generate({ + "prompt_token_ids": input_ids, + "sampling_params": sampling_params, + "multi_modal_data": ImagePixelData(image), + }) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index d7ebe7106977..7f144ffc053a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -31,6 +31,8 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.vlm_base import VisionLanguageModelBase from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import get_dummy_image_data from vllm.sequence import SamplerOutput logger = logging.get_logger(__name__) @@ -184,6 +186,9 @@ def forward(self, output_imgs = [] output_len = [] + if isinstance(img_sizes, torch.Tensor): + img_sizes.squeeze_(0) + for _bs in range(bs): h, w = img_sizes h = h // 336 @@ -261,6 +266,9 @@ class Phi3VImagePixelInputs(TypedDict): """Shape: (batch_size, num_channels, height, width)""" +@MULTIMODAL_REGISTRY.register_image_feature_input() +@MULTIMODAL_REGISTRY.register_image_pixel_input() +@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) class Phi3VForCausalLM(VisionLanguageModelBase): def __init__(self, @@ -285,8 +293,8 @@ def forward(self, image_input: Phi3VImagePixelInputs = None): if image_input is not None: inputs_embeds = self.vision_embed_tokens( - input_ids, image_input, - self.vision_language_config.image_input_shape) + input_ids, image_input["pixel_values"], + image_input["image_sizes"]) input_ids = None else: From 1b1f3f290fd7e190d497ce9c18f06232e60eaeb4 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 8 Jun 2024 11:50:13 +0800 Subject: [PATCH 11/30] remove phi3v feature inputs --- vllm/model_executor/models/phi3v.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 7f144ffc053a..73d97333dd2a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -266,7 +266,6 @@ class Phi3VImagePixelInputs(TypedDict): """Shape: (batch_size, num_channels, height, width)""" -@MULTIMODAL_REGISTRY.register_image_feature_input() @MULTIMODAL_REGISTRY.register_image_pixel_input() @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) class Phi3VForCausalLM(VisionLanguageModelBase): From 8db0d2599b996249ed8c0962e5b6a4b7d0bde1a2 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 11 Jun 2024 16:15:17 +0000 Subject: [PATCH 12/30] refactor phi3v --- examples/phi3v_example.py | 12 ++++-------- vllm/model_executor/models/phi3v.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index bb169c1298e4..79e8d41aafc1 100644 --- a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -2,7 +2,6 @@ import subprocess 
from PIL import Image -from transformers import AutoProcessor from vllm import LLM, SamplingParams from vllm.multimodal.image import ImagePixelData @@ -10,13 +9,12 @@ def run_phi3v(): model_path = "microsoft/Phi-3-vision-128k-instruct" - processor = AutoProcessor.from_pretrained(model_path, - trust_remote_code=True) llm = LLM( model=model_path, trust_remote_code=True, + max_model_len=4096, image_input_type="pixel_values", - image_token_id=-1, + image_token_id=32044, image_input_shape="1,3,1008,1344", image_feature_size=1024, disable_image_processor=False, @@ -29,14 +27,12 @@ def run_phi3v(): # single-image prompt prompt = "What is shown in this image?" - prompt = f"{user_prompt}<|image_1|>\n{prompt}{suffix}{assistant_prompt}" + prompt = user_prompt+"<|image|>"*1921+f"\n{prompt}{suffix}{assistant_prompt}" - inputs = processor(prompt, image, return_tensors="pt") - input_ids = inputs["input_ids"].squeeze(0).tolist() sampling_params = SamplingParams(temperature=0, max_tokens=64) outputs = llm.generate({ - "prompt_token_ids": input_ids, + "prompt": prompt, "sampling_params": sampling_params, "multi_modal_data": ImagePixelData(image), }) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 73d97333dd2a..4cfb862a5c41 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -263,6 +263,7 @@ def forward(self, class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor + image_sizes: torch.Tensor """Shape: (batch_size, num_channels, height, width)""" @@ -284,15 +285,38 @@ def __init__(self, self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + + expected_input_type = self.vision_language_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if expected_input_type == ImageInputType.PIXEL_VALUES: + + if pixel_values is not None and image_sizes is not None: + return Phi3VImagePixelInputs( + type="pixel_values", + data=pixel_values, + image_sizes=image_sizes + ) + + return None + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - image_input: Phi3VImagePixelInputs = None): + **kwargs: object): + image_input = kwargs.pop("image_input", None) + if image_input is not None: + image_input = self._parse_and_validate_image_input(**image_input) + input_ids[input_ids==self.vision_language_config.image_token_id] = -1 inputs_embeds = self.vision_embed_tokens( - input_ids, image_input["pixel_values"], + input_ids, image_input["data"], image_input["image_sizes"]) input_ids = None From 7388bcde552f0e7ab626ed3976c78fb173704cb6 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 12 Jun 2024 19:48:31 +0800 Subject: [PATCH 13/30] deprecate phi3v image_input --- vllm/model_executor/models/phi3v.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4cfb862a5c41..bd553fe5c85b 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -310,10 +310,9 @@ def forward(self, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, **kwargs: object): - image_input = kwargs.pop("image_input", None) + image_input = 
self._parse_and_validate_image_input(**kwargs) if image_input is not None: - image_input = self._parse_and_validate_image_input(**image_input) input_ids[input_ids==self.vision_language_config.image_token_id] = -1 inputs_embeds = self.vision_embed_tokens( input_ids, image_input["data"], From 3f3d2b8a421e37de8dd2c837e8f5b3ec538e6042 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 12 Jun 2024 23:20:37 +0800 Subject: [PATCH 14/30] add phi3v test --- tests/models/test_phi3v.py | 110 ++++++++++++++++++++++++++++ vllm/model_executor/models/phi3v.py | 15 ++-- 2 files changed, 117 insertions(+), 8 deletions(-) create mode 100644 tests/models/test_phi3v.py diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py new file mode 100644 index 000000000000..15d5d91db365 --- /dev/null +++ b/tests/models/test_phi3v.py @@ -0,0 +1,110 @@ +from typing import List, Tuple + +import pytest +from transformers import AutoTokenizer + +from vllm.config import VisionLanguageConfig + +from ..conftest import IMAGE_FILES + +# pytestmark = pytest.mark.phi3v + +# The image token is placed before "user" on purpose so that the test can pass +HF_IMAGE_PROMPTS = [ + "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", + "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", +] + +assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES) + + +def iter_phi3v_configs(model_name: str): + image_hw_to_feature_size = { + (1008, 1344): 1921, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=32044, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) + + +model_and_vl_config = [ + *iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"), +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str], + vlm_config: VisionLanguageConfig, model_id: str): + """Sanitize vllm output to be comparable with hf output. + The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, + x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... + It also reduces `output_str` from "bla" to "bla". + """ + input_ids, output_str = vllm_output + image_token_id = vlm_config.image_token_id + + tokenizer = AutoTokenizer.from_pretrained(model_id) + image_token_str = tokenizer.decode(image_token_id) + + hf_input_ids = [ + input_id for idx, input_id in enumerate(input_ids) + if input_id != image_token_id or input_ids[idx - 1] != image_token_id + ] + hf_output_str = output_str \ + .replace(image_token_str * vlm_config.image_feature_size, "") + + return hf_input_ids, hf_output_str + + +# TODO: Add test for `tensor_parallel_size` [ref: PR #3883] +@pytest.mark.parametrize("model_and_config", model_and_vl_config) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models(hf_runner, vllm_runner, hf_images, vllm_images, + model_and_config, dtype: str, max_tokens: int) -> None: + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalData objects and corresponding + vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. 
+ The text output is sanitized to be able to compare with hf. + """ + model_id, vlm_config = model_and_config + + with hf_runner(model_id, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, + max_tokens, + images=hf_images) + + vllm_image_prompts = [ + p.replace("<|image_1|>", "<|image|>" * vlm_config.image_feature_size + "") + for p in HF_IMAGE_PROMPTS + ] + + with vllm_runner(model_id, + dtype=dtype, + max_model_len=4096, + enforce_eager=True, + **vlm_config.as_cli_args_dict()) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, + max_tokens, + images=vllm_images) + + for i in range(len(HF_IMAGE_PROMPTS)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_to_hf_output( + vllm_outputs[i], vlm_config, model_id) + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bd553fe5c85b..792348572a7c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -293,14 +293,13 @@ def _parse_and_validate_image_input( expected_input_type = self.vision_language_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType - if expected_input_type == ImageInputType.PIXEL_VALUES: - - if pixel_values is not None and image_sizes is not None: - return Phi3VImagePixelInputs( - type="pixel_values", - data=pixel_values, - image_sizes=image_sizes - ) + if expected_input_type != ImageInputType.PIXEL_VALUES: + return None + + if pixel_values is not None and image_sizes is not None: + return Phi3VImagePixelInputs(type="pixel_values", + data=pixel_values, + image_sizes=image_sizes) return None From 0705a626a5e2ca378b6ea55b5ebc81cd3aba8208 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 13 Jun 2024 22:22:33 +0800 Subject: [PATCH 15/30] fix phi3v test --- tests/models/test_phi3v.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 15d5d91db365..ef493183fef3 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -59,7 +59,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str], if input_id != image_token_id or input_ids[idx - 1] != image_token_id ] hf_output_str = output_str \ - .replace(image_token_str * vlm_config.image_feature_size, "") + .replace(image_token_str * vlm_config.image_feature_size, "") \ + .replace("", " ").replace("<|user|>", "").replace("<|end|>\n<|assistant|>", " ") return hf_input_ids, hf_output_str @@ -93,7 +94,6 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images, with vllm_runner(model_id, dtype=dtype, - max_model_len=4096, enforce_eager=True, **vlm_config.as_cli_args_dict()) as vllm_model: vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, From 5f21d3ffbcce23a986cfc30781794c2011531e61 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 13 Jun 2024 22:31:52 +0800 Subject: [PATCH 16/30] format code --- examples/phi3v_example.py | 11 ++++------- tests/models/test_phi3v.py | 16 ++++++++-------- vllm/model_executor/models/phi3v.py | 13 +++++-------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index 79e8d41aafc1..d5e60ae1ee3a 100644 --- 
a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -16,18 +16,15 @@ def run_phi3v(): image_input_type="pixel_values", image_token_id=32044, image_input_shape="1,3,1008,1344", - image_feature_size=1024, + image_feature_size=1921, disable_image_processor=False, ) - image = Image.open("images/stop_sign.jpg") - user_prompt = "<|user|>\n" - assistant_prompt = "<|assistant|>\n" - suffix = "<|end|>\n" + image = Image.open("images/cherry_blossom.jpg") # single-image prompt - prompt = "What is shown in this image?" - prompt = user_prompt+"<|image|>"*1921+f"\n{prompt}{suffix}{assistant_prompt}" + prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501 + prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "") sampling_params = SamplingParams(temperature=0, max_tokens=64) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index ef493183fef3..9286ea1aa701 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -7,11 +7,9 @@ from ..conftest import IMAGE_FILES -# pytestmark = pytest.mark.phi3v - # The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = [ - "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", + "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", ] @@ -28,7 +26,7 @@ def iter_phi3v_configs(model_name: str): (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), ]: yield (model_name, - VisionLanguageConfig(image_input_type=input_type, + VisionLanguageConfig(image_input_type=input_type, image_feature_size=f, image_token_id=32044, image_input_shape=input_shape, @@ -60,7 +58,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str], ] hf_output_str = output_str \ .replace(image_token_str * vlm_config.image_feature_size, "") \ - .replace("", " ").replace("<|user|>", "").replace("<|end|>\n<|assistant|>", " ") + .replace("", " ").replace("<|user|>", "") \ + .replace("<|end|>\n<|assistant|>", " ") return hf_input_ids, hf_output_str @@ -84,11 +83,12 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images, with hf_runner(model_id, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, - max_tokens, - images=hf_images) + max_tokens, + images=hf_images) vllm_image_prompts = [ - p.replace("<|image_1|>", "<|image|>" * vlm_config.image_feature_size + "") + p.replace("<|image_1|>", + "<|image|>" * vlm_config.image_feature_size + "") for p in HF_IMAGE_PROMPTS ] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 792348572a7c..51a0bd11263d 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -303,19 +303,16 @@ def _parse_and_validate_image_input( return None - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - **kwargs: object): + attn_metadata: AttentionMetadata, **kwargs: object): image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: - input_ids[input_ids==self.vision_language_config.image_token_id] = -1 + input_ids[input_ids == + self.vision_language_config.image_token_id] = -1 inputs_embeds = self.vision_embed_tokens( - input_ids, image_input["data"], - image_input["image_sizes"]) + input_ids, 
image_input["data"], image_input["image_sizes"]) input_ids = None else: From 3b7f86afa581aeddaf8ad90f75a0a46a00920c7b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 14 Jun 2024 00:12:28 +0800 Subject: [PATCH 17/30] add phi3_v to get_full_image_text_prompt and test marker --- tests/models/test_phi3v.py | 2 ++ vllm/multimodal/utils.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 9286ea1aa701..7576771eef39 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -7,6 +7,8 @@ from ..conftest import IMAGE_FILES +pytestmark = pytest.mark.llava + # The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = [ "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index c6311d60e0bd..509f791d27c6 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -79,6 +79,8 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str, if config.hf_config.model_type in ("llava", "llava_next"): full_prompt = f"{image_prompt}\n{text_prompt}" + elif config.hf_config.model_type == 'phi3_v': + full_prompt = f"{image_prompt}\n{text_prompt}" else: raise ValueError( f"Unsupported model type: {config.hf_config.model_type}") From ced2c3d5c7e2d5a8ca1eb147d0da6ce799b37f84 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 14 Jun 2024 00:26:14 +0800 Subject: [PATCH 18/30] add docs --- docs/source/models/supported_models.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 5d3f55be1271..f4673dc27092 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -135,6 +135,10 @@ Alongside each architecture, we include some popular models that use it. - Phi-3-Small - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - + * - :code:`Phi3VForCausalLM` + - Phi-3-Vision + - :code:`microsoft/Phi-3-vision-128k-instruct`, etc. + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. From 1d82bb1c8ec6a4a05e1a8ca25338ab5c78310fcd Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 14 Jun 2024 22:48:24 +0800 Subject: [PATCH 19/30] clear phi3v model implement --- vllm/model_executor/models/phi3v.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 51a0bd11263d..2f7c42860a1c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -120,9 +120,6 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None: self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) self.sub_GN = nn.Parameter( torch.empty([1, 1, 1, self.image_dim_out * 4])) - logger.info( - 'learnable separator enabled for hd transform' - 'hd_transform_order = %s', self.hd_transform_order) dim_projection = hidden_size depth = 2 @@ -143,6 +140,7 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor: + "process and merge text embeddings with image embeddings." 
MAX_INPUT_ID = int(1e9) img_embeds = pixel_values @@ -294,7 +292,9 @@ def _parse_and_validate_image_input( ImageInputType = VisionLanguageConfig.ImageInputType if expected_input_type != ImageInputType.PIXEL_VALUES: - return None + raise ValueError( + f"Unexpected image input type: {expected_input_type}." + "Phi3v only support pixel_values input currently.") if pixel_values is not None and image_sizes is not None: return Phi3VImagePixelInputs(type="pixel_values", From 4739c451d0c1945e2aa94c967a3f8eee17ba7607 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 14 Jun 2024 22:53:20 +0800 Subject: [PATCH 20/30] ignore phi3v cpu test --- .buildkite/run-cpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 6a86bc0ebfb6..f5b8a3fb4cb6 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -21,4 +21,4 @@ docker exec cpu-test bash -c "cd tests; pip install pytest Pillow protobuf bash ../.buildkite/download-images.sh cd ../ - pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" + pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_phi3v.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" From 59fe2c1d505f2cb41ed356e6946527f6adfef6f6 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 14 Jun 2024 23:01:29 +0800 Subject: [PATCH 21/30] fix doc strings --- vllm/model_executor/models/phi3v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 2f7c42860a1c..5b5f33d73a66 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -140,7 +140,7 @@ def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor: - "process and merge text embeddings with image embeddings." 
+ """process and merge text embeddings with image embeddings.""" MAX_INPUT_ID = int(1e9) img_embeds = pixel_values From 38ed4d90465f56aedb22ed46e50e96251e519994 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 15 Jun 2024 00:11:16 +0800 Subject: [PATCH 22/30] fix phi3v test flash_attn import --- tests/conftest.py | 3 +++ tests/models/test_phi3v.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 29a4f126ff92..5652a359a227 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -146,6 +146,7 @@ def __init__( model_name: str, dtype: str = "half", *, + model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, is_vision_model: bool = False, ) -> None: @@ -168,11 +169,13 @@ def __init__( else: auto_cls = AutoModelForCausalLM + model_kwargs = model_kwargs if model_kwargs is not None else {} self.model = self.wrap_device( auto_cls.from_pretrained( model_name, torch_dtype=torch_dtype, trust_remote_code=True, + **model_kwargs, )) self.tokenizer = AutoTokenizer.from_pretrained( diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 7576771eef39..ccb9cdcb26ce 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -37,7 +37,8 @@ def iter_phi3v_configs(model_name: str): model_and_vl_config = [ - *iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"), + # *iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"), + *iter_phi3v_configs("/data/LLM-model/Phi-3-vision-128k-instruct"), ] @@ -83,7 +84,10 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images, """ model_id, vlm_config = model_and_config - with hf_runner(model_id, dtype=dtype) as hf_model: + # use eager mode for hf runner, since phi3_v didn't work with flash_attn + hf_model_kwargs = {"_attn_implementation": "eager"} + with hf_runner(model_id, dtype=dtype, + model_kwargs=hf_model_kwargs) as hf_model: hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, max_tokens, images=hf_images) From d2fbecf27f3e971cf5f825fc8b08c72d5e17667b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 15 Jun 2024 00:40:26 +0800 Subject: [PATCH 23/30] fix phi3v test --- tests/models/test_phi3v.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index ccb9cdcb26ce..fbd81c5b59e8 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -37,8 +37,7 @@ def iter_phi3v_configs(model_name: str): model_and_vl_config = [ - # *iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"), - *iter_phi3v_configs("/data/LLM-model/Phi-3-vision-128k-instruct"), + *iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"), ] From 2bbaecd34824e43dc2ae8638761c8735aca42f77 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 15 Jun 2024 01:45:45 +0800 Subject: [PATCH 24/30] add torchvision to requirements-test.txt --- requirements-test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-test.txt b/requirements-test.txt index 8b68e0e93966..d15c10f335e0 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,6 +14,7 @@ peft requests ray sentence-transformers # required for embedding +torchvision==0.18.0 # required for phi3v # Benchmarking aiohttp From ce62fadf0d13e649e66d9f7f0c96a5dcb4b3b09f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 15 Jun 2024 10:09:27 +0800 Subject: [PATCH 25/30] increase phi3v max_model_len to 2048 --- 
tests/models/test_phi3v.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index fbd81c5b59e8..b5f15ff104d3 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -98,6 +98,7 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images, ] with vllm_runner(model_id, + max_model_len=2048, dtype=dtype, enforce_eager=True, **vlm_config.as_cli_args_dict()) as vllm_model: From 1d785908769ad2e0a13b47b15983620420d9fad5 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 15 Jun 2024 14:42:56 +0800 Subject: [PATCH 26/30] decrease phi3v max_tokens to 8 --- tests/models/test_phi3v.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index b5f15ff104d3..607ad95e8c36 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -4,6 +4,7 @@ from transformers import AutoTokenizer from vllm.config import VisionLanguageConfig +from vllm.utils import is_cpu from ..conftest import IMAGE_FILES @@ -55,8 +56,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str], image_token_str = tokenizer.decode(image_token_id) hf_input_ids = [ - input_id for idx, input_id in enumerate(input_ids) - if input_id != image_token_id or input_ids[idx - 1] != image_token_id + input_id if input_id != image_token_id else 0 + for idx, input_id in enumerate(input_ids) ] hf_output_str = output_str \ .replace(image_token_str * vlm_config.image_feature_size, "") \ @@ -66,10 +67,17 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str], return hf_input_ids, hf_output_str +target_dtype = "half" +if is_cpu(): + target_dtype = "bfloat16" + + # TODO: Add test for `tensor_parallel_size` [ref: PR #3883] +# Since we use _attn_implementation="eager" for hf_runner, here is +# numeric difference for longer context and test can't pass @pytest.mark.parametrize("model_and_config", model_and_vl_config) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [8]) def test_models(hf_runner, vllm_runner, hf_images, vllm_images, model_and_config, dtype: str, max_tokens: int) -> None: """Inference result should be the same between hf and vllm. 
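For reference, the image_feature_size of 1921 that the example script and the tests above use for a 1008x1344 input follows directly from the per-image token count computed in Phi3HDImageEmbedding.forward, temp_len = (h*w + 1)*144 + 1 + (h + 1)*12, where h and w are the image height and width divided by the 336-pixel crop size. The following is a minimal standalone sketch of that arithmetic; the helper name is ours and is not part of these patches:

def phi3v_image_feature_size(height: int, width: int) -> int:
    # h x w sub-image crops of 336px plus one global crop, each contributing a
    # 12x12 grid of merged patches (144 tokens); separators add one glb_GN
    # token plus one sub_GN per row: 12 rows for the global crop and h*12 rows
    # for the sub-image grid, i.e. (h + 1) * 12 in total.
    h, w = height // 336, width // 336
    return (h * w + 1) * 144 + 1 + (h + 1) * 12

assert phi3v_image_feature_size(1008, 1344) == 1921

This is why the example prompt expands the single <|image_1|> placeholder into "<|image|>" * 1921 before generation.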
From da1392ce0b626419bb2c6d8fa21c209858ee4b7d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 17 Jun 2024 22:24:20 +0800 Subject: [PATCH 27/30] optimize image embedding and update requirements.txt --- requirements-test.txt | 2 +- vllm/model_executor/models/phi3v.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index d15c10f335e0..c827e66751f4 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,7 +14,7 @@ peft requests ray sentence-transformers # required for embedding -torchvision==0.18.0 # required for phi3v +torchvision # required for phi3v # Benchmarking aiohttp diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 5b5f33d73a66..d9d5e51112fb 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -91,9 +91,13 @@ def get_img_features(self, class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): """Phi3 Image embedding with HD transform.""" - def __init__(self, config: PretrainedConfig, wte=None) -> None: + def __init__(self, + vision_language_config: VisionLanguageConfig, + config: PretrainedConfig, + wte=None) -> None: super().__init__(wte) + self.image_token_id = vision_language_config.image_token_id # n_embed or hidden_size hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size @@ -142,7 +146,6 @@ def forward(self, image_sizes=None) -> torch.FloatTensor: """process and merge text embeddings with image embeddings.""" - MAX_INPUT_ID = int(1e9) img_embeds = pixel_values img_sizes = image_sizes @@ -156,8 +159,7 @@ def forward(self, input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - positions = torch.nonzero( - (input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False) + positions = input_ids == self.image_token_id select = False @@ -261,8 +263,10 @@ def forward(self, class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor + """Shape: (batch_size, 1 + num_patches, num_channels, height, width)""" + image_sizes: torch.Tensor - """Shape: (batch_size, num_channels, height, width)""" + """Shape: (batch_size, 2)""" @MULTIMODAL_REGISTRY.register_image_pixel_input() @@ -278,7 +282,7 @@ def __init__(self, self.config = config self.model = LlamaModel(config, cache_config, quant_config) self.vision_embed_tokens = Phi3HDImageEmbedding( - config, self.model.embed_tokens) + vision_language_config, config, self.model.embed_tokens) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() From 9c080becb228939b49b7ef2a3e28c4b6c882a26d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 17 Jun 2024 22:25:48 +0800 Subject: [PATCH 28/30] remove changing input_ids to -1 --- vllm/model_executor/models/phi3v.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index d9d5e51112fb..27069198966b 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -313,8 +313,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: - input_ids[input_ids == - self.vision_language_config.image_token_id] = -1 inputs_embeds = self.vision_embed_tokens( input_ids, image_input["data"], image_input["image_sizes"]) From 
0cd9d267073e35c5566405f3a9600de53728d150 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 17 Jun 2024 22:51:36 +0800 Subject: [PATCH 29/30] fix a typo and image embedding --- vllm/model_executor/models/phi3v.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 27069198966b..e8f190d3fc4f 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -118,7 +118,7 @@ def __init__(self, self.hd_transform_order = config.embd_layer.get( 'hd_transform_order', 'glb_sub') # with_hd_transform and with_learnable_separator should have same value - assert self.use_hd_transform & self.with_learnable_separator + assert self.use_hd_transform and self.with_learnable_separator # 1024 * 4, merge spatial to channel dimension self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) @@ -159,7 +159,7 @@ def forward(self, input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - positions = input_ids == self.image_token_id + positions = torch.nonzero(input_ids == self.image_token_id) select = False From e77bb7695d927e6ce5ac0db330fd660627362afd Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 17 Jun 2024 22:55:55 +0800 Subject: [PATCH 30/30] update comment for phi3v --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index c827e66751f4..fef0ede7be0f 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,7 +14,7 @@ peft requests ray sentence-transformers # required for embedding -torchvision # required for phi3v +torchvision # required for the image processor of phi3v # Benchmarking aiohttp
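As a sanity check on the HD-transform reshapes refactored in this series, the following standalone sketch reproduces the shape bookkeeping from Phi3HDImageEmbedding.forward for a 1008x1344 image. It uses dummy zero tensors in place of the real CLIP features and the learnable glb_GN/sub_GN separators, assumes no padded crops, and is not part of the patches themselves; it only confirms that the glb_sub layout yields the 1921 tokens of width 4096 expected by the projection MLP:

import torch

C = 1024                        # CLIP ViT-L/14-336 hidden size
H = 24                          # 24x24 patch grid per 336-pixel crop
h, w = 1008 // 336, 1344 // 336
num_crops = h * w               # sub-image crops actually used (no padding here)

# dummy patch features: one global crop plus the sub-image crops
feats = torch.zeros(1 + num_crops, H * H, C)
glb_GN = torch.zeros(1, 1, 4 * C)      # stand-ins for the learnable separators
sub_GN = torch.zeros(1, 1, 1, 4 * C)

# global crop: 24x24x1024 -> 12x12x4096, then one sub_GN appended per row
glb = feats[:1].reshape(1, H // 2, 2, H // 2, 2, C) \
               .permute(0, 1, 3, 2, 4, 5).reshape(1, H // 2, H // 2, 4 * C)
glb = torch.cat([glb, sub_GN.repeat(1, H // 2, 1, 1)], dim=2).reshape(1, -1, 4 * C)

# sub-image crops: merge 2x2 patches to channels, tile into an (h*12) x (w*12)
# grid, then one sub_GN appended per row
sub = feats[1:].reshape(num_crops, H // 2, 2, H // 2, 2, C) \
               .permute(0, 1, 3, 2, 4, 5).reshape(num_crops, -1, 4 * C)
sub = sub.reshape(1, h, w, 12, 12, -1).permute(0, 1, 3, 2, 4, 5) \
         .reshape(1, h * 12, w * 12, 4 * C)
sub = torch.cat([sub, sub_GN.repeat(1, h * 12, 1, 1)], dim=2).reshape(1, -1, 4 * C)

tokens = torch.cat([glb, glb_GN, sub], dim=1)    # glb_sub ordering
assert tokens.shape == (1, 1921, 4 * C)          # matches image_feature_size=1921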