[Bugfix] Add fully sharded layer for QKVParallelLinearWithLora #5665

Merged
7 commits merged on Jun 21, 2024
Changes from 6 commits
2 changes: 1 addition & 1 deletion examples/offline_inference.py
@@ -19,4 +19,4 @@
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
14 changes: 9 additions & 5 deletions tests/lora/test_baichuan.py
@@ -64,7 +64,8 @@ def test_baichuan_lora(baichuan_lora_files):


@pytest.mark.skip("Requires multiple GPUs")
def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
@@ -75,7 +76,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

del llm_tp1
@@ -87,7 +89,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

del llm_tp2
@@ -101,10 +104,11 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

del llm_tp4
cleanup()

assert output_tp1 == output_tp4
assert output_tp1 == output_tp4
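
For orientation, this is roughly how the fully_sharded_loras flag exercised by the test above would be used directly; a minimal sketch assuming the same vllm.LLM / LoRARequest API, with placeholder model name and adapter path:

# Minimal sketch, assuming the vllm.LLM / LoRARequest API used by this test;
# the model name and adapter path are placeholders, not taken from this PR.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM("baichuan-inc/Baichuan-7B",      # placeholder model
          enable_lora=True,
          max_loras=4,
          max_lora_rank=64,
          tensor_parallel_size=2,
          trust_remote_code=True,
          fully_sharded_loras=True)        # flag exercised by this test
outputs = llm.generate(
    ["What is the capital of France?"],
    SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("adapter", 1, "/path/to/lora"))  # placeholder path
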
7 changes: 5 additions & 2 deletions tests/lora/test_layers.py
@@ -12,7 +12,8 @@
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -684,7 +685,9 @@ def create_column_parallel_packed_layer():
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = QKVParallelLinearWithLora(linear)
lora_linear = QKVParallelLinearWithLora(
linear
) if not fully_shard else QKVParallelLinearWithShardedLora(linear)

@dataclass
class FakeConfig:
58 changes: 55 additions & 3 deletions vllm/lora/fully_sharded_layers.py
@@ -12,6 +12,7 @@
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level

@@ -90,11 +91,11 @@ def can_replace_layer(cls, source_layer: nn.Module,
def _mcp_apply(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method.

The main difference is the step by shard_size for lora_b which can
vary for QKVParallelLinearWithShardedLora but is constant for
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
@@ -167,14 +168,65 @@ def can_replace_layer(cls, source_layer: nn.Module,
)


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
"""
Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.
"""

def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)

bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
# now have column partitioned output

output = output.view(*out_orig_shape)
return output

@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
"""
Differs from MergedQKVParallelLinearWithLora by slicing the
LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.
"""

def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
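
To make the docstrings above concrete, the rank-dimension slicing of LoRA A can be sketched with plain tensors; the shapes below are illustrative and not taken from a real model:

# Illustrative sketch of S-LoRA-style slicing along the rank dimension, as in
# slice_lora_a above. lora_a is laid out as (input_size, rank); each
# tensor-parallel rank keeps rank // tp_size columns, computes its partial
# shrink, and an all-gather then reconstructs the full (tokens, rank) buffer
# before the expand with this rank's column shard of LoRA B.
import torch

input_size, rank = 4096, 16
tp_size, tp_rank = 4, 1
lora_a = torch.randn(input_size, rank)

shard_size = rank // tp_size                          # 4 rank columns per TP rank
start = tp_rank * shard_size
lora_a_shard = lora_a[:, start:start + shard_size]    # (input_size, rank // tp_size)

x = torch.randn(8, input_size)                        # 8 tokens
partial = x @ lora_a_shard                            # (8, rank // tp_size) on this rank
# tensor_model_parallel_all_gather(partial) would yield the full (8, rank)
# intermediate that is then multiplied by the column-sharded LoRA B.
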
36 changes: 21 additions & 15 deletions vllm/lora/layers.py
@@ -641,6 +641,24 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
self.base_layer.head_size)

def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
return lora_b

def set_lora(
self,
index: int,
@@ -650,21 +668,8 @@ def set_lora(
):
self.reset_lora(index)
if self.tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)

self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -674,6 +679,7 @@
lora_b.T, non_blocking=True)

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
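
A worked example of the column offsets computed in the slice_lora_b helper factored out above; the head counts are made up for illustration only:

# Illustrative numbers: head_size=128, 32 query heads, 2 KV heads, tp_size=4,
# so num_kv_head_replicas = 2 and each rank holds 8 query heads and 1 KV head.
q_proj_total_size = 32 * 128        # 4096
kv_proj_total_size = 2 * 128        # 256
q_proj_shard_size = 8 * 128         # 1024
kv_proj_shard_size = 1 * 128        # 128

tp_rank = 3
num_kv_head_replicas = 2
q_shard_id = tp_rank                            # 3
kv_shard_id = tp_rank // num_kv_head_replicas   # 1

# lora_b columns are laid out as [q | k | v]; each rank takes its own slice:
q_cols = (q_proj_shard_size * q_shard_id,
          q_proj_shard_size * (q_shard_id + 1))               # (3072, 4096)
k_offset = q_proj_total_size                                   # 4096
k_cols = (k_offset + kv_proj_shard_size * kv_shard_id,
          k_offset + kv_proj_shard_size * (kv_shard_id + 1))   # (4224, 4352)
v_offset = k_offset + kv_proj_total_size                       # 4352
v_cols = (v_offset + kv_proj_shard_size * kv_shard_id,
          v_offset + kv_proj_shard_size * (kv_shard_id + 1))   # (4480, 4608)
# The three slices are concatenated along dim=1, exactly as in slice_lora_b.
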
4 changes: 3 additions & 1 deletion vllm/lora/utils.py
@@ -8,7 +8,8 @@
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
@@ -35,6 +36,7 @@
RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA,
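
The registry above is what the layer-replacement helper walks when wrapping model layers with their LoRA counterparts; a simplified sketch of that dispatch for context (the function name is illustrative and the real helper in vllm/lora/utils.py takes additional arguments):

# Simplified sketch of how _all_lora_classes drives layer replacement;
# not the exact vLLM implementation.
def replace_with_lora(layer, lora_config, packed_modules_list, model_config):
    for lora_cls in _all_lora_classes:
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            return lora_cls(layer)
    return layer
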