
[CORE] Add vLLM Backend for FORMAT.GPTQ #190

Merged · 17 commits · Jul 10, 2024
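As context for reviewers, here is a minimal usage sketch of the new backend, distilled from tests/test_vllm.py added in this PR; the model id, prompts, and sampling values come from that test and are illustrative only.

```python
from vllm import SamplingParams
from gptqmodel import BACKEND, GPTQModel

# Load a GPTQ-quantized checkpoint through the new vLLM backend; extra kwargs
# such as gpu_memory_utilization are forwarded to vllm.LLM.
model = GPTQModel.from_quantized(
    "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    device="cuda:0",
    backend=BACKEND.VLLM,
    trust_remote_code=True,
    gpu_memory_utilization=0.2,
)

# load_format="vllm" routes generate() to vllm_generate() instead of the
# transformers generate() path.
outputs = model.generate(
    load_format="vllm",
    prompts=["The future of AI is"],
    sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
)
print(outputs[0].outputs[0].text)
```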
29 changes: 27 additions & 2 deletions gptqmodel/models/base.py
@@ -28,6 +28,7 @@
 from ..utils.importer import select_quant_linear
 from ..utils.marlin import (_validate_marlin_compatibility,
                             _validate_marlin_device_support, prepare_model_for_marlin_load)
+from ..utils.vllm import load_model_by_vllm, vllm_generate
 from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format, convert_gptq_v2_to_v1_format,
                            find_layers, get_checkpoints, get_device, get_module_by_name_prefix,
                            get_module_by_name_suffix, get_moe_layer_modules, gptqmodel_post_init, make_quant,
@@ -590,8 +591,13 @@ def forward(self, *args, **kwargs):

     def generate(self, **kwargs):
         """shortcut for model.generate"""
-        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
-            return self.model.generate(**kwargs)
+        load_format = kwargs.pop('load_format', None)
+        with torch.inference_mode():
+            if load_format == 'vllm':
+                return vllm_generate(self.model, **kwargs)
+            else:
+                with torch.amp.autocast(device_type=self.device.type):
+                    return self.model.generate(**kwargs)

     def prepare_inputs_for_generation(self, *args, **kwargs):
         """shortcut for model.prepare_inputs_for_generation"""
@@ -1126,6 +1132,25 @@ def skip(*args, **kwargs):
             )
             load_checkpoint_in_model = True
             quantize_config.format = FORMAT.GPTQ_V2
+        if backend == BACKEND.VLLM:
+            if quantize_config.format not in (FORMAT.GPTQ, FORMAT.GPTQ_V2):
+                raise ValueError(f"{backend} backend only supports FORMAT.GPTQ and FORMAT.GPTQ_V2: actual = {quantize_config.format}")
+
+            model = load_model_by_vllm(
+                model=model_name_or_path,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )
+
+            model.config = model.llm_engine.model_config
+            model.config.model_type = "vllm"
+
+            return cls(
+                model,
+                quantized=True,
+                quantize_config=quantize_config,
+                qlinear_kernel=None,
+            )

         if backend == BACKEND.MARLIN:
             if is_sharded:
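Worth noting for reviewers: the guard added in from_quantized above rejects any serialized format other than FORMAT.GPTQ / FORMAT.GPTQ_V2 before vLLM is even constructed. A hedged sketch of the failure path (the model id is hypothetical, purely to illustrate the raised error):

```python
from gptqmodel import BACKEND, GPTQModel

try:
    # Hypothetical checkpoint serialized in a non-GPTQ format (e.g. Marlin).
    GPTQModel.from_quantized("some-org/some-marlin-serialized-model", backend=BACKEND.VLLM)
except ValueError as err:
    # Expected: "<BACKEND.VLLM> backend only supports FORMAT.GPTQ and FORMAT.GPTQ_V2: ..."
    print(err)
```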
2 changes: 1 addition & 1 deletion gptqmodel/utils/backend.py
@@ -11,7 +11,7 @@ class BACKEND(Enum):
     MARLIN = 6
     BITBLAS = 7
     QBITS = 8
-
+    VLLM = 9

 def get_backend(backend: str):
     try:
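For completeness, the new enum member is addressable like the existing ones; get_backend() (whose body is elided above) is assumed to resolve a string name to a member.

```python
from gptqmodel import BACKEND

# New member added by this PR.
assert BACKEND.VLLM.value == 9

# Assumed behaviour of get_backend(); exact string handling is not shown in this diff.
# from gptqmodel.utils.backend import get_backend
# assert get_backend("VLLM") is BACKEND.VLLM
```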
24 changes: 24 additions & 0 deletions gptqmodel/utils/vllm.py
@@ -0,0 +1,24 @@

from vllm import LLM, SamplingParams

def load_model_by_vllm(
    model,
    trust_remote_code,
    **kwargs,
):
    # Instantiate a vLLM engine for the given model path/id; extra kwargs
    # (e.g. gpu_memory_utilization) are forwarded to vllm.LLM.
    model = LLM(
        model=model,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )

    return model

def vllm_generate(
    model,
    **kwargs,
):
    # Thin wrapper around vllm.LLM.generate(): expects `prompts` and
    # `sampling_params` to be passed as keyword arguments.
    prompts = kwargs.pop("prompts", None)
    sampling_params = kwargs.pop("sampling_params", None)
    outputs = model.generate(prompts, sampling_params)
    return outputs
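A minimal sketch of calling the two helpers above directly, assuming vLLM is installed; the model id and sampling values are borrowed from tests/test_vllm.py, and gpu_memory_utilization simply demonstrates the **kwargs pass-through to vllm.LLM.

```python
from vllm import SamplingParams
from gptqmodel.utils.vllm import load_model_by_vllm, vllm_generate

llm = load_model_by_vllm(
    model="LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    trust_remote_code=True,
    gpu_memory_utilization=0.2,  # forwarded to vllm.LLM via **kwargs
)

outputs = vllm_generate(
    llm,
    prompts=["The capital of France is"],
    sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
)
print(outputs[0].outputs[0].text)
```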
36 changes: 36 additions & 0 deletions tests/test_vllm.py
@@ -0,0 +1,36 @@
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch

import unittest  # noqa: E402

from vllm import LLM, SamplingParams  # noqa: E402
from gptqmodel import BACKEND, GPTQModel  # noqa: E402


class TestLoadVLLM(unittest.TestCase):
    MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    def test_load_vllm(self):
        model = GPTQModel.from_quantized(
            self.MODEL_ID,
            device="cuda:0",
            backend=BACKEND.VLLM,
            trust_remote_code=True,
            gpu_memory_utilization=0.2,
        )
        outputs = model.generate(
            load_format="vllm",
            prompts=self.prompts,
            sampling_params=self.sampling_params,
        )
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        self.assertTrue(outputs is not None)