[LoRA] Adds support for bias in LoRA #5733

Open: wants to merge 29 commits into base: main

tests/lora/conftest.py (+5, -0)
@@ -169,6 +169,11 @@ def sql_lora_files(sql_lora_huggingface_id):
    return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
def lora_bias_files():
    return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")


@pytest.fixture(scope="session")
def mixtral_lora_files():
    # Note: this module has incorrect adapter_config.json to test
tests/lora/test_lora_bias_e2e.py (+52, -0)
@@ -0,0 +1,52 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "ibm-granite/granite-3b-code-base"


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"  # noqa: E501
    ]
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=256,
                                          stop=["[/assistant]"])
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    generated_texts: List[str] = []
    for output in outputs:
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
    return generated_texts


@pytest.mark.parametrize("lora_bias", [True, False])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
    llm = vllm.LLM(MODEL_PATH,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_lora_rank=8,
                   max_loras=1,
                   enable_lora_bias=lora_bias,
                   tensor_parallel_size=1,
                   fully_sharded_loras=fully_sharded)

    print("lora adapter created")
    # lora_id=0 is falsy, so do_sample passes no LoRARequest: base-model output.
    output1 = do_sample(llm, lora_bias_files, lora_id=0)

    print("lora")
    # lora_id=1 attaches the downloaded adapter.
    output2 = do_sample(llm, lora_bias_files, lora_id=1)

    if lora_bias:
        assert output1 != output2
    else:
        assert output1 == output2
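
For anyone trying the adapter outside the test harness, a minimal quickstart sketch based on the test above (not part of this PR's diff; the adapter repo id comes from the new `lora_bias_files` fixture, and the prompt and `max_tokens` here are arbitrary):

```python
# Sketch only: mirrors the arguments used in test_lora_bias above.
import vllm
from huggingface_hub import snapshot_download
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
llm = vllm.LLM("ibm-granite/granite-3b-code-base",
               enable_lora=True,
               max_lora_rank=8,
               enable_lora_bias=True)
params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(["Write a SQL query that counts the rows in people."],
                       params,
                       lora_request=LoRARequest("sql-bias", 1, lora_path))
print(outputs[0].outputs[0].text)
```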
vllm/config.py (+1, -0)
@@ -1409,6 +1409,7 @@ class LoRAConfig:
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    bias_enabled: bool = False

    def __post_init__(self):
        # Setting the maximum rank to 256 should be able to satisfy the vast
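
A quick sketch of how the new field surfaces on the config object, assuming `max_lora_rank` and `max_loras` are the only required constructor arguments (as the surrounding dataclass suggests):

```python
# Sketch only; field names follow the diff above.
from vllm.config import LoRAConfig

lora_config = LoRAConfig(max_lora_rank=8, max_loras=1, bias_enabled=True)
assert lora_config.bias_enabled  # defaults to False when bias is not requested
```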
vllm/engine/arg_utils.py (+5, -0)
@@ -107,6 +107,7 @@ class EngineArgs:
    tokenizer_pool_extra_config: Optional[dict] = None
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
    enable_lora: bool = False
    enable_lora_bias: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    enable_prompt_adapter: bool = False
@@ -481,6 +482,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
@@ -919,6 +923,7 @@ def create_engine_config(self, ) -> EngineConfig:
                             and parallel_config.use_ray),
        )
        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
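
The new `--enable-lora-bias` flag maps one-to-one onto the dataclass field, so the same configuration can be built programmatically; a sketch (the model name is just a placeholder):

```python
# Sketch only: passing --enable-lora-bias on the CLI is equivalent to setting
# the field below; create_engine_config() forwards it to LoRAConfig.bias_enabled.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="ibm-granite/granite-3b-code-base",
                         enable_lora=True,
                         enable_lora_bias=True,
                         max_loras=1,
                         max_lora_rank=8)
```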
vllm/lora/fully_sharded_layers.py (+33, -0)
@@ -70,6 +70,14 @@ def apply(self, x: torch.Tensor,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output

        if self.bias_stacked is not None:
            # Use a local view instead of overwriting self.bias_stacked, and
            # zero the rows for tokens that have no active LoRA (index -1).
            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
            bias = bias[self.punica_wrapper.token_lora_indices]
            bias[self.punica_wrapper.token_lora_indices == -1] = 0
            output += bias

        output = output.view(*out_orig_shape)
        return output

@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]

        if layer.bias_stacked is not None:
            bias = layer.bias_stacked[idx]
            if bias is not None:
                bias = bias.view(-1, bias.shape[-1])
                bias = bias[layer.punica_wrapper.token_lora_indices]
                bias[layer.punica_wrapper.token_lora_indices == -1] = 0
                output[:, left_offset:left_offset + shard_size] += bias

        layer.punica_wrapper.add_expand_slice(
            output,
            buffers[idx],
@@ -295,6 +312,15 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        if bias is None:
            return bias
        shard_size = self.bias_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)

@@ -318,6 +344,13 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
        # reduced before being used
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size

        if self.bias_stacked is not None:
            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
            bias = bias[self.punica_wrapper.token_lora_indices]
            bias[self.punica_wrapper.token_lora_indices == -1] = 0
            output += bias

        self.punica_wrapper.add_expand_slice(output, buffer,
                                             self.lora_b_stacked, start_idx,
                                             shard_size)
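
The bias handling added above follows one recurring pattern: flatten the stacked bias to (num_lora_slots, output_dim), gather one row per token via `token_lora_indices`, zero the rows for tokens with no active LoRA (index -1), and add the result to the layer output (for the QKV path, into the corresponding column shard). A standalone sketch of that pattern with made-up shapes:

```python
# Standalone illustration only; shapes and values are made up.
import torch

max_loras, out_dim = 4, 8
bias_stacked = torch.randn(max_loras, 1, out_dim)        # one bias row per LoRA slot
token_lora_indices = torch.tensor([0, 0, -1, 2, -1, 1])  # -1 => token has no LoRA
output = torch.zeros(len(token_lora_indices), out_dim)   # stand-in for the layer output

bias = bias_stacked.view(-1, bias_stacked.shape[-1])     # (max_loras, out_dim)
bias = bias[token_lora_indices]                          # gather a bias row per token
bias[token_lora_indices == -1] = 0                       # tokens without LoRA get no bias
output += bias
```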