[ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 #5921

Merged · 4 commits · Jun 28, 2024
14 changes: 12 additions & 2 deletions vllm/model_executor/layers/linear.py
@@ -387,8 +387,13 @@ def weight_loader(self,
                                            None)

         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
@@ -571,8 +576,13 @@ def weight_loader(self,
                                            None)

         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
41 changes: 20 additions & 21 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """

     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()

@@ -111,6 +112,7 @@ def _create_scale_param(
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
Collaborator (review comment on the line above): What's the purpose of this?

Author (reply):
  • For QKVColumnLinear, we create 3 scales (one for each logical partition)
  • There is only one scale for qkv on disk for phi
  • In this PR, we load this one scale into one of the 3 spots
  • In process_weights_after_loading, we select the max of the 3 scales to use
  • By initializing the scales to min, we are guaranteed to select the "real" scale from disk (see the sketch below)
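A minimal standalone sketch (toy values, hypothetical names, not vLLM's actual classes) of why the min-initialization plus max-selection is safe, assuming the usual case of a positive scale stored in the checkpoint:

```python
import torch

# QKVParallelLinear allocates one scale slot per logical partition (q, k, v).
num_partitions = 3
weight_scale = torch.empty(num_partitions, dtype=torch.float32)
# Placeholder: the most negative representable fp8 e4m3 value (-448.0).
weight_scale[:] = torch.finfo(torch.float8_e4m3fn).min

# A phi-style fused checkpoint ships a single positive scale for the whole
# fused qkv tensor; with shard_id=None it is loaded into slot 0 only.
scale_from_disk = torch.tensor(0.02)
weight_scale[0] = scale_from_disk

# process_weights_after_loading collapses the slots with max(); the untouched
# slots still hold -448.0, so the max is guaranteed to be the real scale.
assert weight_scale.max() == scale_from_disk
```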

         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
@@ -169,11 +171,15 @@ def create_weights(
                 **extra_weight_attrs)

     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Optional[Union[str,
+                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}

-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
@@ -205,15 +211,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # WEIGHT_SCALE / WEIGHT
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end

+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

             # WEIGHT
@@ -227,10 +235,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
@@ -317,11 +321,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.kv_scale


-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
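The hunks above call per_tensor_quantize and per_tensor_dequantize, whose bodies are truncated in this view. Below is a minimal sketch of how such helpers behave (simplified stand-ins consistent with the calls above, not the exact vLLM source), followed by the single-scale round trip that the new fused_module_in_checkpoint flag lets a fused checkpoint skip:

```python
from typing import Union

import torch


def per_tensor_quantize(tensor: torch.Tensor,
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    # Simplified: scale into the fp8 e4m3 range, clamp, and cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(tensor: torch.Tensor,
                          inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    # Simplified: upcast and undo the scaling.
    return tensor.to(torch.float32) * inv_scale


# One logical partition of an unfused checkpoint: dequantize with its own
# scale, then requantize with the shared max scale (what the loop in
# process_weights_after_loading does). A fused checkpoint is already
# quantized with a single scale, so that loop is skipped entirely.
weight = torch.randn(4, 8)
old_scale, max_scale = torch.tensor(0.02), torch.tensor(0.05)
qweight = per_tensor_quantize(weight, old_scale)
requantized = per_tensor_quantize(per_tensor_dequantize(qweight, old_scale),
                                  max_scale)
assert requantized.dtype == torch.float8_e4m3fn
```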