Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA] Adds support for bias in LoRA #5733

Open
wants to merge 44 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
e491d72
LoRA Bias Support
Jun 21, 2024
b0ed274
Minor changes
Jun 21, 2024
fced7ec
Ignore types to avoid error
Jun 21, 2024
575032f
Merge branch 'main' of https://github.com/vllm-project/vllm into bias…
Jun 26, 2024
882a8e8
Merge branch 'main' of https://github.com/vllm-project/vllm into bias…
Jun 27, 2024
29a58c2
enable-lora-bias flag
Jun 27, 2024
7e64588
Resolved conflicts
Jun 27, 2024
06ba6cf
yapf formatting
Jun 27, 2024
84a37ea
yapf formatting
Jun 27, 2024
cd1bb03
yapf formatting
Jun 27, 2024
d73cecb
LoRA Bias Support
Jun 21, 2024
857152b
Minor changes
Jun 21, 2024
c02bee6
Ignore types to avoid error
Jun 21, 2024
0eaaecb
enable-lora-bias flag
Jun 27, 2024
5c8acd0
yapf formatting
Jun 27, 2024
f261cf6
yapf formatting
Jun 27, 2024
1c78eb2
yapf formatting
Jun 27, 2024
387be43
Merge branch 'bias-for-lora' of github.com:followumesh/vllm into bias…
Jul 9, 2024
4845dae
E2E test for lora bias
Jul 26, 2024
1562590
Merged main
Jul 26, 2024
e0eca8a
isort imports
Jul 26, 2024
2aacf10
yapf fix
Jul 26, 2024
942f2ab
Mixing bias and non-bias lora in a batch
Jul 28, 2024
f7deaef
Formatting changes
Jul 31, 2024
ecd753d
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Jul 31, 2024
4d1b8f0
Merge: punica api changes
Aug 8, 2024
7e8bad0
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Aug 21, 2024
3aeb63d
Removed assert for lora check
Aug 21, 2024
808e92c
Ruff: Merged if
Aug 21, 2024
be2ed6b
Merged main
Oct 8, 2024
0835078
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Oct 8, 2024
7454ae9
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Oct 8, 2024
9a046f1
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Oct 9, 2024
8d44e86
Incorporated Suggestions
Oct 28, 2024
cda128c
Minor commit
Oct 28, 2024
c584a36
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Oct 28, 2024
d58851e
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Oct 28, 2024
7db0ded
Failure without --enable-lora-bias flag
Oct 28, 2024
7128fa0
Error: bias is present and not enabled
Nov 9, 2024
b8dc556
Merge remote-tracking branch 'upstream/main' into bias-for-lora
Nov 9, 2024
06162b9
Formatting fix
Nov 9, 2024
b762fe2
Formatting fix
Nov 9, 2024
3b6beb0
Check for list of None
Nov 9, 2024
70a40f6
Check for list of None
Nov 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions csrc/punica/bgmv/bgmv_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 49408) \
f(in_T, out_T, W_T, narrow, 60544) \
f(in_T, out_T, W_T, narrow, 60672) \
f(in_T, out_T, W_T, narrow, 64000) \
Expand Down Expand Up @@ -182,6 +183,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 49408, narrow) \
f(in_T, out_T, W_T, 60544, narrow) \
f(in_T, out_T, W_T, 60672, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
Expand Down
254 changes: 254 additions & 0 deletions tests/lora/test_lora_bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import pytest
import torch

from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice

from .utils import DummyLoRAManager

TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4]
QKV_TENSOR_SIZES = [
(8192, 1024, 1024),
(8192 // 8, 1024 // 8, 1024 // 8),
(4096, 4096, 4096),
(4096 // 2, 4096 // 2, 4096 // 2),
]
BATCH_SIZES = [8, 32, 256]
RANKS = [8]
DTYPES = [torch.float16]
TOLERANCES = {
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora(m, n, k, rank, dtype) -> None:
manager = DummyLoRAManager()

module_name = "module"
weight = torch.rand([m, n], device="cuda", dtype=dtype)

manager.init_random_lora(module_name, weight, rank=rank)
lora = manager.get_module_lora(module_name)

input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = input @ lora.lora_a @ lora.lora_b * lora.scaling

lora_a_stack = torch.zeros(8,
1,
lora.lora_a.shape[1],
lora.lora_a.shape[0],
device="cuda",
dtype=dtype)
lora_b_stack = torch.zeros(8,
1,
lora.lora_b.shape[1],
lora.lora_b.shape[0],
device="cuda",
dtype=dtype)
lora_bias = torch.zeros(8,
1,
lora.lora_b.shape[1],
device="cuda",
dtype=dtype)

for i in range(lora_a_stack.shape[0]):
lora_a_stack[i][0] = lora.lora_a.T
lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T

output = torch.zeros(k, m, device="cuda", dtype=dtype)
_apply_lora(
input, lora_a_stack, lora_b_stack,
torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"),
output, lora_bias)

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.full((len(input), ), -1, device="cuda"),
output, lora_bias)
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
if m % 2 != 0:
pytest.skip("m must be divisible by 2")
if m // 2 not in TENSOR_SIZES:
pytest.skip("m//2 must be in TENSOR_SIZES")

manager = DummyLoRAManager()

module_name = "module"
weight = torch.rand([m // 2, n], device="cuda", dtype=dtype)

manager.init_random_lora(module_name + "1", weight, rank=rank)
lora_1 = manager.get_module_lora(module_name + "1")
manager.init_random_lora(module_name + "2", weight, rank=rank)
lora_2 = manager.get_module_lora(module_name + "2")

input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = torch.cat([
input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling,
input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling
],
dim=1)

lora_a_stacks = [
torch.zeros(8,
1,
lora_1.lora_a.shape[1],
lora_1.lora_a.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(8,
1,
lora_1.lora_b.shape[1],
lora_1.lora_b.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_bias_stacks = [
torch.zeros(8,
1,
lora_1.lora_b.shape[1],
device="cuda",
dtype=dtype) for i in range(2)
]
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_1.lora_a.T
lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
lora_a_stacks[1][i][0] = lora_2.lora_a.T
lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T

output = torch.zeros(k, m, device="cuda", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="cuda"), output, (m // 2, m // 2),
lora_bias_stacks)

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="cuda"),
output, (m // 2, m // 2), lora_bias_stacks)
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()


@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
manager = DummyLoRAManager()

module_name = "module"
weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype)
weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype)

manager.init_random_lora(module_name + "q", weight_q, rank=rank)
lora_q = manager.get_module_lora(module_name + "q")
manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
lora_k = manager.get_module_lora(module_name + "k")
manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
lora_v = manager.get_module_lora(module_name + "v")

input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = torch.cat([
input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling,
input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling,
input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling
],
dim=1)

lora_a_stacks = [
torch.zeros(8,
1,
lora_q.lora_a.shape[1],
lora_q.lora_a.shape[0],
device="cuda",
dtype=dtype)
] + [
torch.zeros(8,
1,
lora_k.lora_a.shape[1],
lora_k.lora_a.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(8,
1,
lora_q.lora_b.shape[1],
lora_q.lora_b.shape[0],
device="cuda",
dtype=dtype)
] + [
torch.zeros(8,
1,
lora_k.lora_b.shape[1],
lora_k.lora_b.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_bias_stacks = [
torch.zeros(8,
1,
lora_q.lora_b.shape[1],
device="cuda",
dtype=dtype)
] + [
torch.zeros(8,
1,
lora_k.lora_b.shape[1],
device="cuda",
dtype=dtype) for i in range(2)
]
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_q.lora_a.T
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
lora_a_stacks[1][i][0] = lora_k.lora_a.T
lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T
lora_a_stacks[2][i][0] = lora_v.lora_a.T
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T

output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="cuda"), output, (qkv[0], qkv[1], qkv[2]),
lora_bias_stacks)

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="cuda"),
output, (qkv[0], qkv[1], qkv[2]),
lora_bias_stacks)
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()
27 changes: 27 additions & 0 deletions vllm/lora/fully_sharded_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def apply(self, x: torch.Tensor,
self.indices[:self.indices_len[0]], 0, 1.0)
# now have column partitioned output

if self.bias_stacked is not None:
self.bias_stacked = self.bias_stacked.view(-1,
self.bias_stacked.shape[-1])
self.bias_stacked = self.bias_stacked[self.indices]
output += self.bias_stacked

output = output.view(*out_orig_shape)
return output

Expand Down Expand Up @@ -119,6 +125,13 @@ def _mcp_apply(x, bias, layer):
layer.lora_b_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0,
left_offset, shard_size)
if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.indices[:layer.indices_len[0]]]
output[:, left_offset: left_offset + shard_size] += bias

left_offset += shard_size

output = output.view(*out_orig_shape)
Expand Down Expand Up @@ -277,6 +290,15 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
lora_b = lora_b[:, start_idx:end_idx]
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
shard_size = self.bias_stacked.shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias

def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)

Expand All @@ -302,6 +324,11 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
self.indices[:self.indices_len[0]], 0, 1.0,
start_idx, shard_size)

if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.indices[:self.indices_len[0]]]
output += bias

output = output.view(*out_orig_shape)
return output

Expand Down
Loading
Loading