Skip to content

Commit

Permalink
Enhance autogptq backend to support VL models (#1860)
Browse files Browse the repository at this point in the history
* Enhance autogptq backend to support VL models

* update dependencies for autogptq

* remove redundant auto-gptq dependency

* Convert base64 to image_url for Qwen-VL model

* implemented model inference for qwen-vl

* remove user prompt from generated answer

* fixed write image error

---------

Co-authored-by: Binghua Wu <[email protected]>
  • Loading branch information
thiner and Binghua Wu committed Mar 26, 2024
1 parent e58410f commit b7ffe66
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 18 deletions.
56 changes: 49 additions & 7 deletions backend/python/autogptq/autogptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import sys
import os
import time
import base64

import grpc
import backend_pb2
import backend_pb2_grpc

from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextGenerationPipeline

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
Expand All @@ -28,9 +30,19 @@ def LoadModel(self, request, context):
if request.Device != "":
device = request.Device

tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=request.UseFastTokenizer)
# support loading local model files
model_path = os.path.join(os.environ.get('MODELS_PATH', './'), request.Model)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=request.TrustRemoteCode)

model = AutoGPTQForCausalLM.from_quantized(request.Model,
# support model `Qwen/Qwen-VL-Chat-Int4`
if "qwen-vl" in request.Model.lower():
self.model_name = "Qwen-VL-Chat"
model = AutoModelForCausalLM.from_pretrained(model_path,
trust_remote_code=request.TrustRemoteCode,
use_triton=request.UseTriton,
device_map="auto").eval()
else:
model = AutoGPTQForCausalLM.from_quantized(model_path,
model_basename=request.ModelBaseName,
use_safetensors=True,
trust_remote_code=request.TrustRemoteCode,
Expand All @@ -55,6 +67,11 @@ def Predict(self, request, context):
if request.TopP != 0.0:
top_p = request.TopP


prompt_images = self.recompile_vl_prompt(request)
compiled_prompt = prompt_images[0]
print(f"Prompt: {compiled_prompt}", file=sys.stderr)

# Implement Predict RPC
pipeline = TextGenerationPipeline(
model=self.model,
Expand All @@ -64,10 +81,17 @@ def Predict(self, request, context):
top_p=top_p,
repetition_penalty=penalty,
)
t = pipeline(request.Prompt)[0]["generated_text"]
# Remove prompt from response if present
if request.Prompt in t:
t = t.replace(request.Prompt, "")
t = pipeline(compiled_prompt)[0]["generated_text"]
print(f"generated_text: {t}", file=sys.stderr)

if compiled_prompt in t:
t = t.replace(compiled_prompt, "")
# house keeping. Remove the image files from /tmp folder
for img_path in prompt_images[1]:
try:
os.remove(img_path)
except Exception as e:
print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)

return backend_pb2.Result(message=bytes(t, encoding='utf-8'))

Expand All @@ -78,6 +102,24 @@ def PredictStream(self, request, context):
# Not implemented yet
return self.Predict(request, context)

def recompile_vl_prompt(self, request):
    """Prepare the prompt for vision-language (VL) models.

    For Qwen-VL models, each base64-encoded image in ``request.Images`` is
    decoded and written to a temporary file under ``/tmp``, and every
    ``[img-<i>]`` placeholder in ``request.Prompt`` is replaced with the
    Qwen-VL inline image tag ``<img>path</img>,`` pointing at that file.
    For any other model the prompt is returned unmodified.

    Args:
        request: gRPC predict request; only ``Prompt`` and ``Images``
            (a list of base64-encoded image payloads) are read here.

    Returns:
        tuple: ``(prompt, image_paths)`` where ``image_paths`` lists the
        temporary image files written; the caller is responsible for
        deleting them after inference.
    """
    prompt = request.Prompt
    image_paths = []

    # LoadModel only assigns self.model_name in the Qwen-VL branch, so
    # fall back to "" to avoid an AttributeError for other model types.
    if "qwen-vl" in getattr(self, "model_name", "").lower():
        for i, img in enumerate(request.Images):
            # Include the image index in the filename: a millisecond
            # timestamp alone collides when several images in the same
            # request are decoded within the same millisecond, silently
            # overwriting earlier files.
            timestamp = str(int(time.time() * 1000))
            img_path = f"/tmp/vl-{timestamp}-{i}.jpg"
            with open(img_path, "wb") as f:
                f.write(base64.b64decode(img))
            image_paths.append(img_path)
            # Qwen-VL expects inline image references as <img>path</img>.
            prompt = prompt.replace(f"[img-{i}]", "<img>" + img_path + "</img>,")
    return (prompt, image_paths)

def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
Expand Down
13 changes: 10 additions & 3 deletions backend/python/autogptq/autogptq.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
####
# NOTE: This file is deprecated.
# Please use ../common-env/transformers/transformers.yml to manage dependencies instead.
####
name: autogptq
channels:
- defaults
Expand All @@ -24,12 +28,12 @@ dependencies:
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- accelerate==0.23.0
- accelerate==0.27.0
- aiohttp==3.8.5
- aiosignal==1.3.1
- async-timeout==4.0.3
- attrs==23.1.0
- auto-gptq==0.4.2
- auto-gptq==0.7.1
- certifi==2023.7.22
- charset-normalizer==3.3.0
- datasets==2.14.5
Expand Down Expand Up @@ -59,6 +63,7 @@ dependencies:
- nvidia-nccl-cu12==2.18.1
- nvidia-nvjitlink-cu12==12.2.140
- nvidia-nvtx-cu12==12.1.105
- optimum==1.17.1
- packaging==23.2
- pandas==2.1.1
- peft==0.5.0
Expand All @@ -75,9 +80,11 @@ dependencies:
- six==1.16.0
- sympy==1.12
- tokenizers==0.14.0
- torch==2.1.0
- tqdm==4.66.1
- torch==2.2.1
- torchvision==0.17.1
- transformers==4.34.0
- transformers_stream_generator==0.0.5
- triton==2.1.0
- typing-extensions==4.8.0
- tzdata==2023.3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ dependencies:
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- accelerate==0.23.0
- accelerate==0.27.0
- aiohttp==3.8.5
- aiosignal==1.3.1
- async-timeout==4.0.3
- auto-gptq==0.7.1
- attrs==23.1.0
- bark==0.1.5
- bitsandbytes==0.43.0
Expand Down Expand Up @@ -69,6 +70,7 @@ dependencies:
- nvidia-nccl-cu12==2.18.1
- nvidia-nvjitlink-cu12==12.2.140
- nvidia-nvtx-cu12==12.1.105
- optimum==1.17.1
- packaging==23.2
- pandas
- peft==0.5.0
Expand All @@ -87,15 +89,15 @@ dependencies:
- six==1.16.0
- sympy==1.12
- tokenizers
- torch==2.1.2
- torch==2.2.1
- torchvision==0.17.1
- torchaudio==2.1.2
- tqdm==4.66.1
- triton==2.1.0
- typing-extensions==4.8.0
- tzdata==2023.3
- urllib3==1.26.17
- xxhash==3.4.1
- auto-gptq==0.6.0
- yarl==1.9.2
- soundfile
- langid
Expand All @@ -116,5 +118,6 @@ dependencies:
- vocos
- vllm==0.3.2
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
prefix: /opt/conda/envs/transformers
6 changes: 4 additions & 2 deletions backend/python/common-env/transformers/transformers-rocm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies:
- pip:
- --pre
- --extra-index-url https://download.pytorch.org/whl/nightly/
- accelerate==0.23.0
- accelerate==0.27.0
- auto-gptq==0.7.1
- aiohttp==3.8.5
- aiosignal==1.3.1
- async-timeout==4.0.3
Expand Down Expand Up @@ -82,14 +83,14 @@ dependencies:
- triton==2.1.0
- typing-extensions==4.8.0
- tzdata==2023.3
- auto-gptq==0.6.0
- urllib3==1.26.17
- xxhash==3.4.1
- yarl==1.9.2
- soundfile
- langid
- wget
- unidecode
- optimum==1.17.1
- pyopenjtalk-prebuilt
- pypinyin
- inflect
Expand All @@ -105,5 +106,6 @@ dependencies:
- vocos
- vllm==0.3.2
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
prefix: /opt/conda/envs/transformers
9 changes: 6 additions & 3 deletions backend/python/common-env/transformers/transformers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ dependencies:
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- accelerate==0.23.0
- accelerate==0.27.0
- aiohttp==3.8.5
- aiosignal==1.3.1
- auto-gptq==0.7.1
- async-timeout==4.0.3
- attrs==23.1.0
- bark==0.1.5
Expand Down Expand Up @@ -56,6 +57,7 @@ dependencies:
- multiprocess==0.70.15
- networkx
- numpy==1.26.0
- optimum==1.17.1
- packaging==23.2
- pandas
- peft==0.5.0
Expand All @@ -74,13 +76,13 @@ dependencies:
- six==1.16.0
- sympy==1.12
- tokenizers
- torch==2.1.2
- torch==2.2.1
- torchvision==0.17.1
- torchaudio==2.1.2
- tqdm==4.66.1
- triton==2.1.0
- typing-extensions==4.8.0
- tzdata==2023.3
- auto-gptq==0.6.0
- urllib3==1.26.17
- xxhash==3.4.1
- yarl==1.9.2
Expand All @@ -103,5 +105,6 @@ dependencies:
- vocos
- vllm==0.3.2
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
prefix: /opt/conda/envs/transformers

0 comments on commit b7ffe66

Please sign in to comment.