Skip to content

Commit

Permalink
[Model] Add base class for LoRA-supported models (vllm-project#5018)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and prashantgupta24 committed Jul 1, 2024
1 parent 205d24f commit 8399340
Show file tree
Hide file tree
Showing 20 changed files with 270 additions and 75 deletions.
3 changes: 3 additions & 0 deletions docs/source/models/lora.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Using LoRA adapters
===================

This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.

LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.

Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save
them locally with

Expand Down
3 changes: 2 additions & 1 deletion vllm/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Sequence as GenericSequence

import torch
import torch.types

from vllm.utils import is_pin_memory_available

Expand Down Expand Up @@ -64,7 +65,7 @@ def create_dummy_lora_weights(
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
Expand Down
6 changes: 3 additions & 3 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import LRUCache, is_pin_memory_available

logger = init_logger(__name__)
Expand Down Expand Up @@ -363,7 +364,7 @@ class LoRAModelManager:

def __init__(
self,
model: nn.Module,
model: SupportsLoRA,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
Expand Down Expand Up @@ -411,7 +412,7 @@ def __init__(
# embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4

self.model: nn.Module = model
self.model = model
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
Expand All @@ -428,7 +429,6 @@ def __init__(
self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self

@property
def capacity(self) -> int:
Expand Down
20 changes: 13 additions & 7 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu

Expand Down Expand Up @@ -64,26 +65,31 @@ def _get_quantization_config(


def _get_model_initialization_kwargs(
model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
vlm_config: Optional[VisionLanguageConfig],
) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}
if hasattr(model_class, "supported_lora_modules"):

if supports_lora(model_class):
# lora_config=None is used to disable LoRA
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif issubclass(model_class, VisionLanguageModelBase):
if vision_language_config is None:

if supports_vision(model_class):
if vlm_config is None:
raise ValueError("Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")

extra_kwargs["vision_language_config"] = vision_language_config
extra_kwargs["vlm_config"] = vlm_config

return extra_kwargs


Expand Down
11 changes: 9 additions & 2 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from .interfaces import SupportsLoRA


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
Expand Down Expand Up @@ -292,7 +294,9 @@ def forward(
return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
Expand All @@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):

def __init__(
self,
config,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config)
Expand Down
11 changes: 9 additions & 2 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA


class GLMAttention(nn.Module):

Expand Down Expand Up @@ -322,7 +324,9 @@ def forward(
return hidden_states


class ChatGLMForCausalLM(nn.Module):
class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
Expand All @@ -345,7 +349,10 @@ def __init__(
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/decilm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Iterable, Optional, Tuple

import torch
from transformers import PretrainedConfig
from transformers import LlamaConfig

from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import (
Expand Down Expand Up @@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):

def __init__(
self,
config: Optional[PretrainedConfig] = None,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from .interfaces import SupportsLoRA

logger = init_logger(__name__)


Expand Down Expand Up @@ -288,7 +290,9 @@ def forward(
return hidden_states


class GemmaForCausalLM(nn.Module):
class GemmaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand Down Expand Up @@ -319,9 +323,11 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
Expand Down
9 changes: 8 additions & 1 deletion vllm/model_executor/models/gpt_bigcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from .interfaces import SupportsLoRA


class GPTBigCodeAttention(nn.Module):

Expand Down Expand Up @@ -230,7 +232,9 @@ def forward(
return hidden_states


class GPTBigCodeForCausalLM(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {"c_attn": ["c_attn"]}

supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
Expand All @@ -250,7 +254,10 @@ def __init__(
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
Expand Down
130 changes: 130 additions & 0 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)

from typing_extensions import TypeGuard

from vllm.config import LoRAConfig, VisionLanguageConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@runtime_checkable
class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs)."""

supports_vision: ClassVar[Literal[True]]

def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]

def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
...


@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
...


@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
...


def supports_vision(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)

return isinstance(model, SupportsVision)


@runtime_checkable
class SupportsLoRA(Protocol):
"""The interface required for all models that support LoRA."""

supports_lora: ClassVar[Literal[True]]

packed_modules_mapping: ClassVar[Dict[str, List[str]]]
supported_lora_modules: ClassVar[List[str]]
embedding_modules: ClassVar[Dict[str, str]]
embedding_padding_modules: ClassVar[List[str]]

# lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
supports_lora: Literal[True]

packed_modules_mapping: Dict[str, List[str]]
supported_lora_modules: List[str]
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]

def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...


@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
...


@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
...


def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
result = _supports_lora(model)

if not result:
lora_attrs = (
"packed_modules_mapping",
"supported_lora_modules",
"embedding_modules",
"embedding_padding_modules",
)
missing_attrs = tuple(attr for attr in lora_attrs
if not hasattr(model, attr))

if getattr(model, "supports_lora", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_lora=True`, "
"but is missing LoRA-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all LoRA-specific attributes, "
"but does not set `supports_lora=True`.", model)

return result


def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)

return isinstance(model, SupportsLoRA)
Loading

0 comments on commit 8399340

Please sign in to comment.