[VLM] Use SequenceData.from_token_counts to create dummy data (vllm…
DarkLight1337 authored and siddharth9820 committed Sep 30, 2024
1 parent e5c764d commit 3f4b1bb
Showing 12 changed files with 73 additions and 80 deletions.
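The change in a nutshell: instead of hand-assembling array-typed token buffers, the dummy-data builders now call a new SequenceData.from_token_counts constructor that takes (token_id, count) pairs. A minimal standalone sketch of what that constructor does (plain Python, not vLLM code; the "l" array typecode and the 32000 placeholder id are assumptions for illustration):

from array import array
from functools import reduce

VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # assumed typecode for token-id arrays


def from_token_counts(*token_counts: tuple[int, int]) -> array:
    """Concatenate runs of repeated token ids, e.g. (image_token_id, n) then (0, padding)."""
    if len(token_counts) == 0:
        return array(VLLM_TOKEN_ID_ARRAY_TYPE, [])
    arrs = [
        array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
        for token_id, count in token_counts
    ]
    return reduce(array.__add__, arrs)


# Dummy multimodal prompt: 4 image placeholders followed by 4 pad tokens.
print(list(from_token_counts((32000, 4), (0, 4))))
# -> [32000, 32000, 32000, 32000, 0, 0, 0, 0]

As the per-file diffs below show, each model's dummy-data helper no longer needs to import array or VLLM_TOKEN_ID_ARRAY_TYPE directly.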
2 changes: 1 addition & 1 deletion vllm/inputs/registry.py
@@ -125,7 +125,7 @@ def _default_dummy_data_factory(
# Avoid circular import
from vllm.sequence import SequenceData

-dummy_seq_data = SequenceData.from_counts({0: seq_len})
+dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
dummy_multi_modal_data = None

return dummy_seq_data, dummy_multi_modal_data
13 changes: 6 additions & 7 deletions vllm/model_executor/models/blip.py
@@ -1,6 +1,5 @@
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
-from array import array
from typing import Optional, Union

import torch
@@ -19,7 +18,7 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+from vllm.sequence import SequenceData

try:
from xformers import ops as xops
@@ -53,6 +52,7 @@ def get_max_blip_image_tokens(
def dummy_seq_data_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
seq_len: int,
+num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
@@ -62,11 +62,10 @@ def dummy_seq_data_for_blip(
else:
image_feature_size = image_feature_size_override

-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[image_token_id]) * image_feature_size
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[0]) * (seq_len - image_feature_size)
-return SequenceData(token_ids)
+return SequenceData.from_token_counts(
+(image_token_id, image_feature_size * num_images),
+(0, seq_len - image_feature_size * num_images),
+)


def dummy_image_for_blip(
13 changes: 5 additions & 8 deletions vllm/model_executor/models/blip2.py
@@ -1,4 +1,3 @@
-from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

@@ -18,8 +17,7 @@
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData

from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
@@ -429,11 +427,10 @@ def dummy_seq_data_for_blip2(
else:
image_feature_size = image_feature_size_override

-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[image_token_id]) * image_feature_size * num_images
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[0]) * (seq_len - image_feature_size * num_images)
-return SequenceData(token_ids)
+return SequenceData.from_token_counts(
+(image_token_id, image_feature_size * num_images),
+(0, seq_len - image_feature_size * num_images),
+)


def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
13 changes: 5 additions & 8 deletions vllm/model_executor/models/chameleon.py
@@ -1,4 +1,3 @@
-from array import array
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict)
@@ -32,8 +31,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import print_warning_once

from .interfaces import SupportsMultiModal
@@ -72,11 +70,10 @@ def dummy_seq_data_for_chameleon(
else:
image_feature_size = image_feature_size_override

-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[image_token_id]) * image_feature_size * num_images
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[0]) * (seq_len - image_feature_size * num_images)
-return SequenceData(token_ids)
+return SequenceData.from_token_counts(
+(image_token_id, image_feature_size * num_images),
+(0, seq_len - image_feature_size * num_images),
+)


def dummy_image_for_chameleon(
12 changes: 5 additions & 7 deletions vllm/model_executor/models/clip.py
@@ -1,6 +1,5 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
-from array import array
from typing import Iterable, List, Optional, Tuple, Union

import torch
@@ -20,7 +19,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+from vllm.sequence import SequenceData

try:
from xformers import ops as xops
@@ -62,11 +61,10 @@ def dummy_seq_data_for_clip(
else:
image_feature_size = image_feature_size_override

-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[image_token_id]) * image_feature_size * num_images
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[0]) * (seq_len - image_feature_size * num_images)
-return SequenceData(token_ids)
+return SequenceData.from_token_counts(
+(image_token_id, image_feature_size * num_images),
+(0, seq_len - image_feature_size * num_images),
+)


def dummy_image_for_clip(
7 changes: 2 additions & 5 deletions vllm/model_executor/models/minicpmv.py
@@ -23,7 +23,6 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
-from array import array
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
TypedDict)
@@ -56,8 +55,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData

from .idefics2_vision_model import Idefics2VisionTransformer

@@ -259,8 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):


def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
-return SequenceData(token_ids)
+return SequenceData.from_token_counts((0, seq_len))


def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
14 changes: 5 additions & 9 deletions vllm/model_executor/models/pixtral.py
@@ -1,4 +1,3 @@
-from array import array
from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union
@@ -24,8 +23,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData

from .interfaces import SupportsMultiModal
from .utils import init_vllm_registered_model
@@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_feature_size = (size**2) // (patch_size**2)

num_image_tokens = image_feature_size * num_images
+seq_data = SequenceData.from_token_counts(
+(image_token_id, num_image_tokens),
+(0, seq_len - num_image_tokens),
+)

-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[image_token_id]) * num_image_tokens
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[0]) * (seq_len - num_image_tokens)

-seq_data = SequenceData(token_ids)
mm_data = {"image": num_images * [image]}
return seq_data, mm_data

10 changes: 5 additions & 5 deletions vllm/model_executor/models/qwen.py
@@ -7,7 +7,6 @@

import math
import re
from array import array
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, TypedDict, Union)
@@ -45,8 +44,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of

from .utils import flatten_bn, is_pp_missing_parameter, make_layers
@@ -819,7 +817,7 @@ def dummy_data_for_qwen(
# The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes.
if not hasattr(hf_config, "visual"):
-seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
+seq_data = SequenceData.from_token_counts((0, seq_len))
mm_data = None
return seq_data, mm_data

@@ -846,11 +844,13 @@ def dummy_data_for_qwen(
if len(toks) < seq_len:
toks += [0] * (seq_len - len(toks))

+seq_data = SequenceData.from_seqs(toks)

# Build the input images; width/height doesn't actually matter here since
# the data will get resized and the # of tokens per image is constant
image = Image.new("RGB", (224, 224), color=0)
mm_data = {"image": image if num_images == 1 else [image] * num_images}
-return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
+return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
21 changes: 9 additions & 12 deletions vllm/model_executor/models/qwen2_vl.py
@@ -22,7 +22,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
-from array import array
from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
Union)
@@ -66,8 +65,7 @@
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.platforms import current_platform
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import get_processor

logger = init_logger(__name__)
@@ -681,15 +679,14 @@ def dummy_data_for_qwen2_vl(
"--limit-mm-per-prompt.")

hf_config = ctx.get_hf_config(Qwen2VLConfig)
-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[hf_config.vision_start_token_id])
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[hf_config.image_token_id]) * max_llm_image_tokens
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[hf_config.vision_end_token_id])
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[0]) * (seq_len - max_llm_image_tokens - 2)
-dummy_seqdata = SequenceData(token_ids)

+dummy_seqdata = SequenceData.from_token_counts(
+(hf_config.vision_start_token_id, 1),
+(hf_config.image_token_id, max_llm_image_tokens),
+(hf_config.vision_end_token_id, 1),
+(0, seq_len - max_llm_image_tokens - 2),
+)

dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0)

12 changes: 5 additions & 7 deletions vllm/model_executor/models/siglip.py
@@ -2,7 +2,6 @@
within a vision language model."""

import math
-from array import array
from typing import Iterable, List, Optional, Tuple, Union

import torch
@@ -24,7 +23,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+from vllm.sequence import SequenceData

try:
from xformers import ops as xops
@@ -67,11 +66,10 @@ def dummy_seq_data_for_siglip(
else:
image_feature_size = image_feature_size_override

-token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[image_token_id]) * image_feature_size
-token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-[0]) * (seq_len - image_feature_size)
-return SequenceData(token_ids)
+return SequenceData.from_token_counts(
+(image_token_id, image_feature_size * num_images),
+(0, seq_len - image_feature_size * num_images),
+)


def dummy_image_for_siglip(
30 changes: 22 additions & 8 deletions vllm/model_executor/models/ultravox.py
@@ -77,15 +77,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)


-def dummy_data_for_ultravox(
+def dummy_seq_data_for_ultravox(
ctx: InputContext,
seq_len: int,
-mm_counts: Mapping[str, int],
+audio_count: int,
):
-feature_extractor = whisper_feature_extractor(ctx)

-audio_count = mm_counts["audio"]

audio_placeholder = array(
VLLM_TOKEN_ID_ARRAY_TYPE,
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
@@ -96,10 +92,28 @@ def dummy_data_for_ultravox(
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - len(audio_token_ids))

+return SequenceData(audio_token_ids + other_token_ids)


+def dummy_audio_for_ultravox(
+ctx: InputContext,
+audio_count: int,
+):
+feature_extractor = whisper_feature_extractor(ctx)
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-mm_dict = {"audio": [audio_and_sr] * audio_count}
+return {"audio": [audio_and_sr] * audio_count}


+def dummy_data_for_ultravox(
+ctx: InputContext,
+seq_len: int,
+mm_counts: Mapping[str, int],
+):
+audio_count = mm_counts["audio"]
+seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
+mm_dict = dummy_audio_for_ultravox(ctx, audio_count)

-return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
+return (seq_data, mm_dict)


def input_mapper_for_ultravox(ctx: InputContext, data: object):
6 changes: 3 additions & 3 deletions vllm/sequence.py
@@ -171,13 +171,13 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta: Optional[int] = None

@staticmethod
-def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
-if len(counts_by_token) == 0:
+def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
+if len(token_counts) == 0:
return SequenceData.from_seqs([])

arrs = [
array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
-for token_id, count in counts_by_token.items()
+for token_id, count in token_counts
]

return SequenceData(reduce(array.__add__, arrs))
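For reference, the call pattern the model files above converge on, assuming a vLLM build that includes this commit (the concrete token id and size values below are illustrative, not taken from any model config):

from vllm.sequence import SequenceData

seq_len, image_token_id, image_feature_size, num_images = 16, 32000, 4, 2

# Text-only warmup sequence: a single run of pad tokens.
text_only = SequenceData.from_token_counts((0, seq_len))

# Multimodal warmup sequence: image placeholders followed by padding.
with_images = SequenceData.from_token_counts(
    (image_token_id, image_feature_size * num_images),
    (0, seq_len - image_feature_size * num_images),
)

This matches the updated default factory in vllm/inputs/registry.py and the per-model dummy-data helpers shown earlier in the diff.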