Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add GLM-4v support #5358

Closed
wants to merge 19 commits into from
Closed

[Model] Add GLM-4v support #5358

wants to merge 19 commits into from

Conversation

songxxzp
Copy link

@songxxzp songxxzp commented Jun 8, 2024

Overview

This PR support the glm-4v-9b model while maintaining compatibility with chatglm.

FIX #5417
FIX #6097

Changes

  1. Add vision_config for ChatGLMConfig
  2. Add glm4 vision encoder in vllm/model_executor/models/glm4_vision_encoder.py.
  3. Add optional vision module for ChatGLMModel, making ChatGLMForCausalLM multimodal capable.
  4. Fixed the logic for vision_language_config to ensure proper configuration of vision-language models when lora_config is present. Already fixed by [Model] Add base class for LoRA-supported models #5018
  5. Support custom position_ids (glm-4v use the same position for images tokens).
About custom `position_ids`

glm-4v use the same position for images tokens:

# query = 'Describe these two picture.'
# placeholder_id can be used freely, the image is detected by boi(151339) and eoi(151340) token.
input_ids = [151331, 151333, 151336, 198, 151339, placeholder_id, placeholder_id, ..., placeholder_id, 151340, 74198, 1493, 1378, 6802, 13, 151337]
position_ids = [0, 1, 2, 3, 4, 5, 5, ..., 5, 6, 7, 8, 9, 10, 11, 12]

Therefore, we need to maintain position_ids in SequenceData to support that.

Major Code Changes

Maintain position_ids in SequenceData(Only when position_ids is passed):
vllm/sequence.py

class SequenceData:

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
        position_ids: Optional[List[int]] = None,
    ) -> None:
        self._position_ids: Optional[List[int]] = list(
            position_ids) if position_ids is not None else None
        ...

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self._output_token_ids.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self.cumulative_logprob += logprob
        if self._position_ids is not None:
            self._position_ids.append(self._position_ids[-1] + 1)
    ...

Calculate position_ids for glm-4v:
vllm/model_executor/models/chatglm.py

def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = hf_config.vision_config

    if vision_config is None:
        return llm_inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = (vision_config["image_size"] //
                                    vision_config["patch_size"] //
                                    2)**2  # 1600
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    input_ids = llm_inputs.get("prompt_token_ids")
    position_ids = llm_inputs.get("position_ids")
    if position_ids is None:
        position_ids = list(range(len(input_ids)))
    boi_token_id = hf_config.boi_token_id
    eoi_token_id = hf_config.eoi_token_id
    boi_positions = find_all_positions(input_ids, boi_token_id)
    eoi_positions = find_all_positions(input_ids, eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    new_position_ids = []
    final_processed_position = 0
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        new_input_ids.extend(input_ids[final_processed_position:boi_position +
                                       1])
        new_position_ids.extend(
            list(range(final_processed_position, boi_position + 1)))
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        new_position_ids.extend([boi_position + 1] * image_placeholder_length)
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])
    new_position_ids.extend(
        list(range(final_processed_position, len(input_ids))))

    assert len(new_input_ids) == len(new_position_ids)

    llm_inputs["prompt_token_ids"] = new_input_ids
    llm_inputs["position_ids"] = new_position_ids
    return llm_inputs

Use custom position_ids in model runner:
vllm/worker/model_runner.py

            # input_positions.extend(list(range(computed_len, seq_len)))
            seq_position_ids = seq_data.get_position_ids()
            if seq_position_ids is not None:
                input_positions.extend(
                    list(seq_position_ids[computed_len:seq_len]))
            else:
                input_positions.extend(list(range(computed_len, seq_len)))
About the bugfix

Code Changes

The previous code used an elif statement that prevented the check for subclasses of VisionLanguageModelBase when lora_config was set. This has been updated to use an if statement to ensure that vision_language_config is processed correctly regardless of whether lora_config is present.

Before

In _get_model_initialization_kwargs of vllm/model_executor/model_loader/loader.py.

    if hasattr(model_class, "supported_lora_modules"):
        ...
    elif lora_config:
        ...
    elif issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("Provide `image_input_type` and other vision "
                             "related configurations through LLM entrypoint "
                             "or engine arguments.")

        extra_kwargs["vision_language_config"] = vision_language_config

After

    if hasattr(model_class, "supported_lora_modules"):
        ...
    elif lora_config:
        ...
    if issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("Provide `image_input_type` and other vision "
                             "related configurations through LLM entrypoint "
                             "or engine arguments.")

        extra_kwargs["vision_language_config"] = vision_language_config

Usage

glm-4-9b-chat

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


max_model_len, tp_size = 8192, 1
model_name = "THUDM/glm-4-9b-chat"

llm = LLM(
    model=model_name,
    tensor_parallel_size=tp_size,
    max_model_len=max_model_len,
    trust_remote_code=True,
    enforce_eager=True
)
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0, max_tokens=1024, stop_token_ids=stop_token_ids)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

query = 'Hi!'
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": query}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True
)

input_ids = inputs['input_ids'][0].tolist()

outputs = llm.generate(
    TokensPrompt(**{
        "prompt_token_ids": input_ids,
    }),
    sampling_params=sampling_params
)

print(outputs[0].outputs[0].text)
Hi 👋! I'm ChatGLM, the artificial intelligence assistant, nice to meet you. Feel free to ask me any questions.

glm-4v-9b

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


max_model_len, tp_size = 8192, 1
model_name = "THUDM/glm-4v-9b"

llm = LLM(
    model=model_name,
    tensor_parallel_size=tp_size,
    max_model_len=max_model_len,
    trust_remote_code=True,
    enforce_eager=True
)
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0, max_tokens=1024, stop_token_ids=stop_token_ids)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

query = 'Describe this picture.'
image = Image.open("docs/source/assets/logos/vllm-logo-text-light.png").convert('RGB')
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": query}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True
)

image_tensor = inputs['images']

input_ids = inputs['input_ids'][0].tolist()

outputs = llm.generate(
    TokensPrompt(**{
        "prompt_token_ids": input_ids,
        "multi_modal_data":  {"image": image_tensor},
    }),
    sampling_params=sampling_params
)

print(outputs[0].outputs[0].text)
The image shows a logo with the letters "LLM" in uppercase, bold font. The "L" and "M" are in a dark grey or black color, while the "L" also has a slight shadow effect, giving it a three-dimensional appearance. The "L" on the left side of the logo is unique; it is stylized with a blue and a yellow shape that resembles a flag or a small arrow pointing upwards, with the blue shape being the larger and the yellow shape being the smaller, triangular extension on the right side of the blue shape. The background of the logo is a solid, dark color, which contrasts sharply with the lighter colors of the "L" and the blue and yellow shapes.

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@ywang96 ywang96 self-assigned this Jun 8, 2024
@DarkLight1337
Copy link
Member

Additionally, it addresses an issue with the handling of vision_language_config when used in conjunction with lora_config.

This may also be solved by #5018.

@DarkLight1337
Copy link
Member

DarkLight1337 commented Jun 11, 2024

Thanks for implementing this model! To improve performance, you should try to use vLLM layers instead of the default PyTorch implementations (see here).

Also, can you add a test to ensure the model's consistency with its HuggingFace version (similar to the one for LLaVA).

@HuggingAha
Copy link

Overview

This PR support the glm-4v-9b model while maintaining compatibility with chatglm. Additionally, it addresses an issue with the handling of vision_language_config when used in conjunction with lora_config.

FIX #5417

Changes

  1. Add vision_config for ChatGLMConfig
  2. Add glm4 vision encoder in vllm/model_executor/models/glm4_vision_encoder.py.
  3. Add optional vision module for ChatGLMModel, making ChatGLMForCausalLM multimodal capable.
  4. Fixed the logic for vision_language_config to ensure proper configuration of vision-language models when lora_config is present.

Code Changes

The previous code used an elif statement that prevented the check for subclasses of VisionLanguageModelBase when lora_config was set. This has been updated to use an if statement to ensure that vision_language_config is processed correctly regardless of whether lora_config is present.

Before

In _get_model_initialization_kwargs of vllm/model_executor/model_loader/loader.py.

    if hasattr(model_class, "supported_lora_modules"):
        ...
    elif lora_config:
        ...
    elif issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("Provide `image_input_type` and other vision "
                             "related configurations through LLM entrypoint "
                             "or engine arguments.")

        extra_kwargs["vision_language_config"] = vision_language_config

After

    if hasattr(model_class, "supported_lora_modules"):
        ...
    elif lora_config:
        ...
    if issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("Provide `image_input_type` and other vision "
                             "related configurations through LLM entrypoint "
                             "or engine arguments.")

        extra_kwargs["vision_language_config"] = vision_language_config

Usage

glm-4-9b-chat

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

max_model_len, tp_size = 8192, 1
model_name = "THUDM/glm-4-9b-chat"

boi_token_id = 151339
eoi_token_id = 151340

llm = LLM(
    model=model_name,
    tensor_parallel_size=tp_size,
    max_model_len=max_model_len,
    trust_remote_code=True,
    enforce_eager=True,
    image_input_type="pixel_values",
    image_token_id=boi_token_id,
    image_input_shape="1,3,1120,1120",
    image_feature_size=1602,
    disable_image_processor=True
)
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0, max_tokens=1024, stop_token_ids=stop_token_ids)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

query = 'Hi!'
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": query}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True
)

input_ids = inputs['input_ids'][0].tolist()

outputs = llm.generate(
    {
        "prompt_token_ids": input_ids
    },
    sampling_params=sampling_params
)

print(outputs[0].outputs[0].text)
Hello 👋! I'm ChatGLM, the artificial intelligence assistant, nice to meet you. Feel free to ask me any questions.

glm-4v-9b

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData


max_model_len, tp_size = 8192, 1
model_name = "THUDM/glm-4v-9b"

boi_token_id = 151339
eoi_token_id = 151340

llm = LLM(
    model=model_name,
    tensor_parallel_size=tp_size,
    max_model_len=max_model_len,
    trust_remote_code=True,
    enforce_eager=True,
    image_input_type="pixel_values",
    image_token_id=boi_token_id,
    image_input_shape="1,3,1120,1120",
    image_feature_size=1602,
    disable_image_processor=True
)
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0, max_tokens=1024, stop_token_ids=stop_token_ids)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

query = 'Describe this picture.'
image = Image.open("docs/source/assets/logos/vllm-logo-text-light.png").convert('RGB')
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": query}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True
)

image_tensor = inputs['images']

input_ids = inputs['input_ids'][0].tolist()
boi_token_pos, eoi_token_pos = input_ids.index(boi_token_id), input_ids.index(eoi_token_id)
input_ids = input_ids[:boi_token_pos] + [boi_token_id] * 1602 + input_ids[eoi_token_pos + 1:]

outputs = llm.generate(
    {
        "prompt_token_ids": input_ids,
        "multi_modal_data": ImagePixelData(image_tensor),
    },
    sampling_params=sampling_params
)

print(outputs[0].outputs[0].text)
The image shows a logo with the letters "LLM" in uppercase, bold font. The "L" and "M" are in a dark grey or black color, while the "L" also has a small, triangular accent on the left side in a lighter shade, possibly a pale blue or grey. The background of the logo is a gradient of dark to light grey, creating a sense of depth and shadow, which gives the logo a three-dimensional appearance. The overall design is modern and clean, with a professional look.

PR Checklist (Click to Expand)
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

During testing, I noticed that the model output using vLLM is inconsistent with the output from the official GLM4v code, resulting in a decrease in image understanding capabilities. I have ensured that the parameter settings are consistent (at least numerically), yet this issue persists. Could you please explain what might be causing this discrepancy?

@songxxzp
Copy link
Author

@HuggingAha Thank you for your testing. To assist you better, could you please provide more details such as the input and output, sampling parameters, and any benchmarks you are testing? If it’s a generation task, due to slight differences in the logits between the vLLM and HuggingFace’s implementation (up to the first decimal place), differences may arise after sampling to a certain length. I believe this is normal cumulative error. However, if you find a significant difference in logits or a severe performance drop when measuring benchmarks, it is likely an issue with the implementation. Please feel free to provide more details to help me fix the problem.

@bonninr
Copy link

bonninr commented Jun 26, 2024

I'm getting the following error:

File "/workspace/vllm/vllm/model_executor/models/chatglm.py", line 392, in load_weights
[rank0]: param = params_dict[name]
[rank0]: KeyError: 'transformer.vision.boi'

@songxxzp
Copy link
Author

I'm getting the following error:

File "/workspace/vllm/vllm/model_executor/models/chatglm.py", line 392, in load_weights [rank0]: param = params_dict[name] [rank0]: KeyError: 'transformer.vision.boi'

There appears to be an issue with loading the model weights. It appears that the code is trying to load the weights for the glm-4v, but the parameters for the vision encoder are absent from the weights. If the weights are correct and the configuration is accurate, this could indicate a bug in the code. To assist with troubleshooting, please provide additional details.

@shelleyo9
Copy link

looking forward that vllm can support glm-4v soon

@sjmFDU
Copy link

sjmFDU commented Jul 2, 2024

I'm getting the following error:
File "/workspace/vllm/vllm/model_executor/models/chatglm.py", line 392, in load_weights [rank0]: param = params_dict[name] [rank0]: KeyError: 'transformer.vision.boi'

There appears to be an issue with loading the model weights. It appears that the code is trying to load the weights for the glm-4v, but the parameters for the vision encoder are absent from the weights. If the weights are correct and the configuration is accurate, this could indicate a bug in the code. To assist with troubleshooting, please provide additional details.

I also find this problem. I just run the code on vllm=0.5.0 following your usage.

@sjmFDU
Copy link

sjmFDU commented Jul 2, 2024

I'm getting the following error:
File "/workspace/vllm/vllm/model_executor/models/chatglm.py", line 392, in load_weights [rank0]: param = params_dict[name] [rank0]: KeyError: 'transformer.vision.boi'

There appears to be an issue with loading the model weights. It appears that the code is trying to load the weights for the glm-4v, but the parameters for the vision encoder are absent from the weights. If the weights are correct and the configuration is accurate, this could indicate a bug in the code. To assist with troubleshooting, please provide additional details.

I also find "KeyError: 'transformer.vision.patch_embedding.cls_embedding'", 'KeyError: 'transformer.vision.eoi'. Please help me solve this problem.

@HughesZhang2021
Copy link

looking forward support soon...

@HuggingAha
Copy link

@HuggingAha Thank you for your testing. To assist you better, could you please provide more details such as the input and output, sampling parameters, and any benchmarks you are testing? If it’s a generation task, due to slight differences in the logits between the vLLM and HuggingFace’s implementation (up to the first decimal place), differences may arise after sampling to a certain length. I believe this is normal cumulative error. However, if you find a significant difference in logits or a severe performance drop when measuring benchmarks, it is likely an issue with the implementation. Please feel free to provide more details to help me fix the problem.

Based on my tests, the issue might be due to the misalignment of positional encodings, where the position_ids need to be reset. Such an operation is present in the official code of GLM4v, but I haven't found any modification like this here.

image

image

@DarkLight1337
Copy link
Member

DarkLight1337 commented Jul 3, 2024

Based on my tests, the issue might be due to the misalignment of positional encodings, where the position_ids need to be reset. Such an operation is present in the official code of GLM4v, but I haven't found any modification like this here.

In order to use the KV cache in vLLM, we have to reserve placeholders in the multimodal embeddings before they are passed to the language model. The placeholders are then filled with the image embeddings in

https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/utils.py#L6-L41

So it is necessary to rewrite the code in such a way. Make sure that the embeddings follow the same order as the HF implementation after they are merged.

Btw, #5276 has been merged, so you can start to make changes accordingly to support dynamic number of image tokens. You may refer to this guide for more details.

@bigbigQI
Copy link

bigbigQI commented Jul 3, 2024

thanks for your contribution!

I met a problem when running with vllm server.

I start the server using Python:

python -m vllm.entrypoints.openai.api_server --model THUDM/glm-4v-9b --dtype auto --api-key token-abc123 --trust-remote-code --image-input-type "pixel_values" --image-token-id 151339 --image-input-shape "1,3,1120,1120" --image-feature-size 1602 --disable-image-processor  --enforce-eager

and then use the official OpenAI Python client to requese:

from openai import OpenAI
import base64
openai_api_key = "token-abc123"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base
)


query = '描述这张图片'

def encode_image(image_path):
  with open(image_path, "rb") as image_file:
    return base64.b64encode(image_file.read()).decode('utf-8')

# Path to your image
image_path = "liucheng.png"

# Getting the base64 string
base64_image = encode_image(image_path)

chat_response = client.chat.completions.create(
    model="THUDM/glm-4v-9b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": query},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                    }
            }

        ]
    }]
)

print(chat_response)

There is a error when using the openai client.

Traceback (most recent call last):
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/uvicorn/protocols/http/httptools_impl.py", line 399, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
    return await self.app(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/base.py", line 193, in __call__
    response_sent.set()
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/contextlib.py", line 137, in __exit__
    self.gen.throw(typ, value, traceback)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/_utils.py", line 93, in collapse_excgroups
    raise exc
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/base.py", line 191, in __call__
    response = await self.dispatch_func(request, call_next)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/entrypoints/openai/api_server.py", line 164, in authentication
    return await call_next(request)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/base.py", line 165, in call_next
    raise app_exc
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/base.py", line 151, in coro
    await self.app(scope, receive_or_disconnect, send_no_error)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/cors.py", line 85, in __call__
    await self.app(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/starlette/routing.py", line 72, in app
    response = await func(request)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/fastapi/routing.py", line 278, in app
    raw_response = await run_endpoint_function(
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/fastapi/routing.py", line 191, in run_endpoint_function
    return await dependant.call(**values)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/entrypoints/openai/api_server.py", line 103, in create_chat_completion
    generator = await openai_serving_chat.create_chat_completion(
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/entrypoints/openai/serving_chat.py", line 282, in create_chat_completion
    return await self.chat_completion_full_generator(
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/entrypoints/openai/serving_chat.py", line 482, in chat_completion_full_generator
    async for res in result_generator:
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 673, in generate
    async for output in self._process_request(
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 780, in _process_request
    raise e
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 776, in _process_request
    async for request_output in stream:
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 89, in __anext__
    raise result
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 42, in _log_task_completion
    return_value = task.result()
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 532, in run_engine_loop
    has_requests_in_progress = await asyncio.wait_for(
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/asyncio/tasks.py", line 479, in wait_for
    return fut.result()
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 506, in engine_step
    request_outputs = await self.engine.step_async()
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 235, in step_async
    output = await self.model_executor.execute_model_async(
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/executor/gpu_executor.py", line 117, in execute_model_async
    output = await make_async(self.driver_worker.execute_model
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/worker/worker.py", line 280, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/worker/model_runner.py", line 735, in execute_model
    ) = self.prepare_input_tensors(seq_group_metadata_list)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/worker/model_runner.py", line 681, in prepare_input_tensors
    ) = self._prepare_model_input(seq_group_metadata_list)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/worker/model_runner.py", line 453, in _prepare_model_input
    mm_kwargs = self.multi_modal_input_processor(mm_data)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/multimodal/registry.py", line 141, in process_input
    return self._get_plugin_for_data_type(type(data)) \
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/multimodal/base.py", line 126, in process_input
    return processor(data, model_config, vlm_config)
  File "/home/larkz/miniconda3/envs/vllm_env/lib/python3.9/site-packages/vllm/multimodal/image.py", line 112, in _default_input_processor
    raise RuntimeError("No HuggingFace processor is available"
RuntimeError: No HuggingFace processor is availableto process the image object

@DarkLight1337 DarkLight1337 changed the title [Model][Bugfix] Add GLM-4v support [Model] Add GLM-4v support Jul 3, 2024
@danxuan2022
Copy link

I'm getting the following error:
File "/workspace/vllm/vllm/model_executor/models/chatglm.py", line 392, in load_weights [rank0]: param = params_dict[name] [rank0]: KeyError: 'transformer.vision.boi'

There appears to be an issue with loading the model weights. It appears that the code is trying to load the weights for the glm-4v, but the parameters for the vision encoder are absent from the weights. If the weights are correct and the configuration is accurate, this could indicate a bug in the code. To assist with troubleshooting, please provide additional details.

What's the progress so far?

@DarkLight1337
Copy link
Member

By the way, you can use format.sh to lint your code locally.

@danxuan2022
Copy link

By the way, you can use format.sh to lint your code locally.

After the native GLM-4V supports VLLM deployment, does the finetuned GLM-4V support VLLM deployment?

@B-201
Copy link
Contributor

B-201 commented Jul 18, 2024

This sample code doesn't seem to work anymore.

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData


max_model_len, tp_size = 8192, 1
model_name = "THUDM/glm-4v-9b"

boi_token_id = 151339
eoi_token_id = 151340

llm = LLM(
    model=model_name,
    tensor_parallel_size=tp_size,
    max_model_len=max_model_len,
    trust_remote_code=True,
    enforce_eager=True,
    image_input_type="pixel_values",
    image_token_id=boi_token_id,
    image_input_shape="1,3,1120,1120",
    image_feature_size=1602,
    disable_image_processor=True
)
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0, max_tokens=1024, stop_token_ids=stop_token_ids)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

query = 'Describe this picture.'
image = Image.open("docs/source/assets/logos/vllm-logo-text-light.png").convert('RGB')
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": query}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True
)

image_tensor = inputs['images']

input_ids = inputs['input_ids'][0].tolist()
# boi_token_pos, eoi_token_pos = input_ids.index(boi_token_id), input_ids.index(eoi_token_id)
# input_ids = input_ids[:boi_token_pos] + [boi_token_id] * 1602 + input_ids[eoi_token_pos + 1:]

outputs = llm.generate(
    {
        "prompt_token_ids": input_ids,
        "multi_modal_data": ImagePixelData(image_tensor),
    },
    sampling_params=sampling_params
)

print(outputs[0].outputs[0].text)

@songxxzp
Copy link
Author

@B-201 Sorry about that. It should work now.

@songxxzp
Copy link
Author

@HuggingAha Fixed. Please see “About custom position_ids”. Thanks for the test, by the way.

@B-201
Copy link
Contributor

B-201 commented Jul 18, 2024

@B-201 Sorry about that. It should work now.

Sorry, I didn't understand what you mean. I tested it on the latest commit.

@songxxzp
Copy link
Author

@B-201 Sorry about that. It should work now.

Sorry, I didn't understand what you mean. I tested it on the latest commit.

I just edited the sample code:
#5358 (comment)

@hongsamvo
Copy link

hongsamvo commented Jul 22, 2024

@songxxzp Thank you for your contribution. Can I install vllm from this https://github.com/songxxzp/vllm/tree/glm4v. I got some error while run pip install -e.

@B-201
Copy link
Contributor

B-201 commented Jul 24, 2024

Sorry to bother you, but I was wondering if this PR will be merged soon?

@ZhangYaoFu
Copy link

Does it support batch inference of vLLM + GLM4V?

@DarkLight1337
Copy link
Member

What is the use case of custom position_ids if we can already detect that automatically? Adding position_ids just for this model seems like a code smell.

@tower823
Copy link

@B-201 Sorry about that. It should work now.

Sorry, I didn't understand what you mean. I tested it on the latest commit.

I just edited the sample code: #5358 (comment)

Hi, thanks for your contributions! I just test the latest sample code, met an error, maybe it's because the different version of transformers? can you help~
image

@alexw994
Copy link

@songxxzp Thank you for your contribution. Can I install vllm from this https://github.com/songxxzp/vllm/tree/glm4v. I got some error while run pip install -e.

You can try my PR, which includes a precompiled .whl file and support bnb 4-bit quantation. #7672

@DarkLight1337
Copy link
Member

Closing as superseded by #9242

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet