Skip to content

Commit

Permalink
[CI/Build] Add TP test for vision models (vllm-project#5892)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and jimpang committed Jul 8, 2024
1 parent 7d508d8 commit 5871f39
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 27 deletions.
5 changes: 5 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ steps:
working_dir: "/vllm-workspace/tests"
num_gpus: 2
commands:
- bash ../.buildkite/download-images.sh
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
Expand All @@ -52,10 +53,14 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
Expand Down
51 changes: 51 additions & 0 deletions tests/distributed/test_multimodal_broadcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
The second test will hang if more than one test is run per command, so we need
to run the tests one by one. The solution is to pass arguments (model name) by
environment variables.
Run:
```sh
TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf \
test_multimodal_broadcast.py
TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct \
test_multimodal_broadcast.py
```
"""
import os

import pytest

from vllm.utils import cuda_device_count_stateless

model = os.environ["TEST_DIST_MODEL"]

if model.startswith("llava-hf/llava"):
from ..models.test_llava import model_and_vl_config, run_test
elif model.startswith("microsoft/Phi-3-vision"):
from ..models.test_phi3v import model_and_vl_config, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")


@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets,
tensor_parallel_size: int, dtype: str,
max_tokens: int) -> None:
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(
f"Need at least {tensor_parallel_size} GPUs to run the test.")

distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND")

run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config=model_and_vl_config[0],
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
)
39 changes: 31 additions & 8 deletions tests/models/test_llava.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List, Tuple
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig

from ..conftest import IMAGE_ASSETS
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets

pytestmark = pytest.mark.vlm

Expand Down Expand Up @@ -65,12 +65,17 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
return hf_output_ids, hf_output_str


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
*,
dtype: str,
max_tokens: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
Expand All @@ -96,6 +101,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,

with vllm_runner(model_id,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
Expand All @@ -110,3 +117,19 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=1,
)
49 changes: 36 additions & 13 deletions tests/models/test_phi3v.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List, Tuple
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig
from vllm.utils import is_cpu

from ..conftest import IMAGE_ASSETS
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets

pytestmark = pytest.mark.vlm

Expand Down Expand Up @@ -73,17 +73,17 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
target_dtype = "bfloat16"


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# Since we use _attn_implementation="eager" for hf_runner, here is
# numeric difference for longer context and test can't pass
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
*,
dtype: str,
max_tokens: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
Expand Down Expand Up @@ -116,7 +116,9 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
with vllm_runner(model_id,
max_model_len=2048,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=True,
distributed_executor_backend=distributed_executor_backend,
**vlm_config.as_cli_args_dict()) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
Expand All @@ -130,3 +132,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


# Since we use _attn_implementation="eager" for hf_runner, here is
# numeric difference for longer context and test can't pass
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=1,
)
1 change: 1 addition & 0 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def broadcast_object(self, obj=None):
else:
return self.dequeue()

@staticmethod
def create_from_process_group(pg: ProcessGroup,
max_chunk_bytes,
max_chunks,
Expand Down
4 changes: 3 additions & 1 deletion vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(
self.shm_broadcaster: Optional[ShmRingBufferIO] = None
if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
self.cpu_group, 1 << 20, 6)
self.cpu_group, 1 << 22, 6)

@property
def first_rank(self):
Expand Down Expand Up @@ -690,6 +690,8 @@ def destroy(self):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.shm_broadcaster is not None:
self.shm_broadcaster = None


_WORLD: Optional[GroupCoordinator] = None
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)

return self._select_image_features(
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)

return self._select_image_features(
Expand Down
5 changes: 2 additions & 3 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def forward(self, input_ids: torch.LongTensor,

select = False

target_device = self.img_projection[0].bias.device
target_dtype = self.img_projection[0].bias.dtype

if len(positions.tolist()) > 0:
Expand Down Expand Up @@ -231,7 +230,7 @@ def forward(self, input_ids: torch.LongTensor,
img_set_tensor = []
for _output_img in output_imgs:
img_feature_proj = self.img_projection(
_output_img.to(target_device, target_dtype))
_output_img.to(target_dtype))
img_set_tensor.append(img_feature_proj)
select = True

Expand All @@ -245,7 +244,7 @@ def forward(self, input_ids: torch.LongTensor,
hidden_states[positions[idx, 0],
positions[idx, 1]:positions[idx, 1] +
cnt] = (img_set_tensor[i].to(
hidden_states.device, hidden_states.dtype))
hidden_states.dtype))
idx += cnt

return hidden_states.squeeze(0)
Expand Down

0 comments on commit 5871f39

Please sign in to comment.