From b7ffe6621962952e2a69a6caeb6224f00bcf377d Mon Sep 17 00:00:00 2001
From: "Sebastian.W"
Date: Wed, 27 Mar 2024 01:48:14 +0800
Subject: [PATCH] Enhance autogptq backend to support VL models (#1860)

* Enhance autogptq backend to support VL models

* update dependencies for autogptq

* remove redundant auto-gptq dependency

* Convert base64 to image_url for Qwen-VL model

* implemented model inference for qwen-vl

* remove user prompt from generated answer

* fixed write image error

---------

Co-authored-by: Binghua Wu
---
 backend/python/autogptq/autogptq.py           | 56 ++++++++++++++++---
 backend/python/autogptq/autogptq.yml          | 13 ++++-
 .../transformers/transformers-nvidia.yml      |  9 ++-
 .../transformers/transformers-rocm.yml        |  6 +-
 .../common-env/transformers/transformers.yml  |  9 ++-
 5 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/backend/python/autogptq/autogptq.py b/backend/python/autogptq/autogptq.py
index ffb37569bbc..bbafdd92085 100755
--- a/backend/python/autogptq/autogptq.py
+++ b/backend/python/autogptq/autogptq.py
@@ -5,12 +5,14 @@
 import sys
 import os
 import time
+import base64
 
 import grpc
 import backend_pb2
 import backend_pb2_grpc
+
 from auto_gptq import AutoGPTQForCausalLM
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformers import TextGenerationPipeline
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -28,9 +30,19 @@ def LoadModel(self, request, context):
         if request.Device != "":
             device = request.Device
 
-        tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=request.UseFastTokenizer)
+        # support loading local model files
+        model_path = os.path.join(os.environ.get('MODELS_PATH', './'), request.Model)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=request.TrustRemoteCode)
 
-        model = AutoGPTQForCausalLM.from_quantized(request.Model,
+        # support model `Qwen/Qwen-VL-Chat-Int4`
+        if "qwen-vl" in request.Model.lower():
+            self.model_name = "Qwen-VL-Chat"
+            model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                         trust_remote_code=request.TrustRemoteCode,
+                                                         use_triton=request.UseTriton,
+                                                         device_map="auto").eval()
+        else:
+            model = AutoGPTQForCausalLM.from_quantized(model_path,
             model_basename=request.ModelBaseName,
             use_safetensors=True,
             trust_remote_code=request.TrustRemoteCode,
@@ -55,6 +67,11 @@ def Predict(self, request, context):
         if request.TopP != 0.0:
             top_p = request.TopP
 
+
+        prompt_images = self.recompile_vl_prompt(request)
+        compiled_prompt = prompt_images[0]
+        print(f"Prompt: {compiled_prompt}", file=sys.stderr)
+
         # Implement Predict RPC
         pipeline = TextGenerationPipeline(
             model=self.model,
@@ -64,10 +81,17 @@ def Predict(self, request, context):
             top_p=top_p,
             repetition_penalty=penalty,
         )
-        t = pipeline(request.Prompt)[0]["generated_text"]
-        # Remove prompt from response if present
-        if request.Prompt in t:
-            t = t.replace(request.Prompt, "")
+        t = pipeline(compiled_prompt)[0]["generated_text"]
+        print(f"generated_text: {t}", file=sys.stderr)
+
+        if compiled_prompt in t:
+            t = t.replace(compiled_prompt, "")
+        # house keeping. Remove the image files from /tmp folder
+        for img_path in prompt_images[1]:
+            try:
+                os.remove(img_path)
+            except Exception as e:
+                print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)
 
         return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
 
@@ -78,6 +102,24 @@ def PredictStream(self, request, context):
         # Not implemented yet
         return self.Predict(request, context)
 
+    def recompile_vl_prompt(self, request):
+        prompt = request.Prompt
+        image_paths = []
+
+        if "qwen-vl" in self.model_name.lower():
+            # request.Images is an array which contains base64 encoded images. Iterate the request.Images array, decode and save each image to /tmp folder with a random filename.
+            # Then, save the image file paths to an array "image_paths".
+            # read "request.Prompt", replace "[img-%d]" with the image file paths in the order they appear in "image_paths". Save the new prompt to "prompt".
+            for i, img in enumerate(request.Images):
+                timestamp = str(int(time.time() * 1000))  # Generate timestamp
+                img_path = f"/tmp/vl-{timestamp}.jpg"  # Use timestamp in filename
+                with open(img_path, "wb") as f:
+                    f.write(base64.b64decode(img))
+                image_paths.append(img_path)
+                prompt = prompt.replace(f"[img-{i}]", "<img>" + img_path + "</img>,")
+        else:
+            prompt = request.Prompt
+        return (prompt, image_paths)
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))

diff --git a/backend/python/autogptq/autogptq.yml b/backend/python/autogptq/autogptq.yml
index 19b8e41d1c2..d22b354eb91 100644
--- a/backend/python/autogptq/autogptq.yml
+++ b/backend/python/autogptq/autogptq.yml
@@ -1,3 +1,7 @@
+####
+# Attention! This file is abandoned.
+# Please use the ../common-env/transformers/transformers.yml file to manage dependencies.
+###
 name: autogptq
 channels:
   - defaults
@@ -24,12 +28,12 @@ dependencies:
   - xz=5.4.2=h5eee18b_0
   - zlib=1.2.13=h5eee18b_0
   - pip:
-    - accelerate==0.23.0
+    - accelerate==0.27.0
     - aiohttp==3.8.5
     - aiosignal==1.3.1
     - async-timeout==4.0.3
     - attrs==23.1.0
-    - auto-gptq==0.4.2
+    - auto-gptq==0.7.1
     - certifi==2023.7.22
     - charset-normalizer==3.3.0
     - datasets==2.14.5
@@ -59,6 +63,7 @@ dependencies:
     - nvidia-nccl-cu12==2.18.1
     - nvidia-nvjitlink-cu12==12.2.140
     - nvidia-nvtx-cu12==12.1.105
+    - optimum==1.17.1
     - packaging==23.2
     - pandas==2.1.1
     - peft==0.5.0
@@ -75,9 +80,11 @@ dependencies:
     - six==1.16.0
     - sympy==1.12
     - tokenizers==0.14.0
-    - torch==2.1.0
     - tqdm==4.66.1
+    - torch==2.2.1
+    - torchvision==0.17.1
     - transformers==4.34.0
+    - transformers_stream_generator==0.0.5
     - triton==2.1.0
     - typing-extensions==4.8.0
     - tzdata==2023.3
diff --git a/backend/python/common-env/transformers/transformers-nvidia.yml b/backend/python/common-env/transformers/transformers-nvidia.yml
index 7daafe51804..553612344d9 100644
--- a/backend/python/common-env/transformers/transformers-nvidia.yml
+++ b/backend/python/common-env/transformers/transformers-nvidia.yml
@@ -24,10 +24,11 @@ dependencies:
   - xz=5.4.2=h5eee18b_0
   - zlib=1.2.13=h5eee18b_0
   - pip:
-    - accelerate==0.23.0
+    - accelerate==0.27.0
     - aiohttp==3.8.5
     - aiosignal==1.3.1
     - async-timeout==4.0.3
+    - auto-gptq==0.7.1
     - attrs==23.1.0
     - bark==0.1.5
     - bitsandbytes==0.43.0
@@ -69,6 +70,7 @@ dependencies:
     - nvidia-nccl-cu12==2.18.1
     - nvidia-nvjitlink-cu12==12.2.140
     - nvidia-nvtx-cu12==12.1.105
+    - optimum==1.17.1
     - packaging==23.2
     - pandas
     - peft==0.5.0
@@ -87,7 +89,8 @@ dependencies:
     - six==1.16.0
     - sympy==1.12
     - tokenizers
-    - torch==2.1.2
+    - torch==2.2.1
+    - torchvision==0.17.1
     - torchaudio==2.1.2
     - tqdm==4.66.1
     - triton==2.1.0
@@ -95,7 +98,6 @@ dependencies:
     - tzdata==2023.3
     - urllib3==1.26.17
     - xxhash==3.4.1
-    - auto-gptq==0.6.0
     - yarl==1.9.2
     - soundfile
     - langid
@@ -116,5 +118,6 @@ dependencies:
     - vocos
     - vllm==0.3.2
     - transformers>=4.38.2 # Updated Version
+    - transformers_stream_generator==0.0.5
     - xformers==0.0.23.post1
 prefix: /opt/conda/envs/transformers
diff --git a/backend/python/common-env/transformers/transformers-rocm.yml b/backend/python/common-env/transformers/transformers-rocm.yml
index 5c18d301dc1..fa245bf4cec 100644
--- a/backend/python/common-env/transformers/transformers-rocm.yml
+++ b/backend/python/common-env/transformers/transformers-rocm.yml
@@ -26,7 +26,8 @@ dependencies:
   - pip:
     - --pre
     - --extra-index-url https://download.pytorch.org/whl/nightly/
-    - accelerate==0.23.0
+    - accelerate==0.27.0
+    - auto-gptq==0.7.1
     - aiohttp==3.8.5
     - aiosignal==1.3.1
    - async-timeout==4.0.3
@@ -82,7 +83,6 @@ dependencies:
     - triton==2.1.0
     - typing-extensions==4.8.0
     - tzdata==2023.3
-    - auto-gptq==0.6.0
     - urllib3==1.26.17
     - xxhash==3.4.1
     - yarl==1.9.2
@@ -90,6 +90,7 @@ dependencies:
     - langid
     - wget
     - unidecode
+    - optimum==1.17.1
     - pyopenjtalk-prebuilt
     - pypinyin
     - inflect
@@ -105,5 +106,6 @@ dependencies:
     - vocos
     - vllm==0.3.2
     - transformers>=4.38.2 # Updated Version
+    - transformers_stream_generator==0.0.5
     - xformers==0.0.23.post1
 prefix: /opt/conda/envs/transformers
diff --git a/backend/python/common-env/transformers/transformers.yml b/backend/python/common-env/transformers/transformers.yml
index 5726abaf37c..bdf8c36fb63 100644
--- a/backend/python/common-env/transformers/transformers.yml
+++ b/backend/python/common-env/transformers/transformers.yml
@@ -24,9 +24,10 @@ dependencies:
   - xz=5.4.2=h5eee18b_0
   - zlib=1.2.13=h5eee18b_0
   - pip:
-    - accelerate==0.23.0
+    - accelerate==0.27.0
     - aiohttp==3.8.5
     - aiosignal==1.3.1
+    - auto-gptq==0.7.1
     - async-timeout==4.0.3
     - attrs==23.1.0
     - bark==0.1.5
@@ -56,6 +57,7 @@ dependencies:
     - multiprocess==0.70.15
     - networkx
     - numpy==1.26.0
+    - optimum==1.17.1
     - packaging==23.2
     - pandas
     - peft==0.5.0
@@ -74,13 +76,13 @@ dependencies:
     - six==1.16.0
     - sympy==1.12
     - tokenizers
-    - torch==2.1.2
+    - torch==2.2.1
+    - torchvision==0.17.1
     - torchaudio==2.1.2
     - tqdm==4.66.1
     - triton==2.1.0
     - typing-extensions==4.8.0
     - tzdata==2023.3
-    - auto-gptq==0.6.0
     - urllib3==1.26.17
     - xxhash==3.4.1
     - yarl==1.9.2
@@ -103,5 +105,6 @@ dependencies:
     - vocos
     - vllm==0.3.2
     - transformers>=4.38.2 # Updated Version
+    - transformers_stream_generator==0.0.5
     - xformers==0.0.23.post1
 prefix: /opt/conda/envs/transformers
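
Note on usage (a sketch, not part of the patch): the VL path expects the gRPC
request to carry base64-encoded images in request.Images together with a prompt
that references them by [img-N] placeholders. A minimal illustration of how a
caller would prepare those fields; the file name "cat.jpg" is hypothetical, and
only the Images/Prompt field names come from the backend above:

    import base64

    # Encode one local image the way the backend expects to receive it.
    with open("cat.jpg", "rb") as f:  # hypothetical input file
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    images = [img_b64]                           # -> request.Images
    prompt = "[img-0] What is in this picture?"  # -> request.Prompt
    # recompile_vl_prompt() rewrites this to
    # "<img>/tmp/vl-<timestamp>.jpg</img>, What is in this picture?",
    # the inline-image markup that Qwen-VL's tokenizer understands.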
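
The prompt recompilation itself can be read as a standalone function. A minimal
sketch, assuming the same (prompt, base64 images) inputs as the backend; it
deliberately swaps the millisecond-timestamp filename for tempfile.mkstemp,
since two images written within the same millisecond would otherwise collide:

    import base64
    import os
    import tempfile

    def recompile_vl_prompt(prompt, images_b64):
        """Decode base64 images to temp files and splice their paths into the
        prompt using Qwen-VL's <img>...</img> markup."""
        image_paths = []
        for i, img in enumerate(images_b64):
            # mkstemp guarantees a unique file, avoiding the collision risk
            # of /tmp/vl-<millisecond-timestamp>.jpg in multi-image requests.
            fd, img_path = tempfile.mkstemp(prefix="vl-", suffix=".jpg")
            with os.fdopen(fd, "wb") as f:
                f.write(base64.b64decode(img))
            image_paths.append(img_path)
            prompt = prompt.replace(f"[img-{i}]", "<img>" + img_path + "</img>,")
        return prompt, image_paths

The returned image_paths list is what Predict() later iterates to delete the
temporary files once generation has finished.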