[Bugfix][VLM] Make apply_fp8_linear work with >2D input (vllm-project#9812)

Signed-off-by: NickLucche <[email protected]>
mgoin authored and NickLucche committed Oct 31, 2024
1 parent 06218f1 commit 1910637
Showing 1 changed file with 20 additions and 13 deletions.
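
In short, the change flattens any extra leading dimensions of the activation into a single 2D matrix before the fp8 kernels run, then restores the original leading dimensions on the result. Below is a minimal standalone sketch of that reshape pattern, with a plain matmul standing in for the fused fp8 GEMM and the weight laid out as [in_features, out_features] as in the diff:

import torch

def linear_nd(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Collapse all leading dims (e.g. [batch, seq, hidden]) into rows.
    input_2d = input.view(-1, input.shape[-1])
    # The output keeps the leading dims; the last dim becomes out_features.
    output_shape = [*input.shape[:-1], weight.shape[1]]
    output = input_2d @ weight  # stand-in for the quantized GEMM
    return output.view(*output_shape)

x = torch.randn(2, 8, 16)  # a >2D activation, e.g. from a vision-language model
w = torch.randn(16, 32)
assert linear_nd(x, w).shape == (2, 8, 32)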
33 changes: 20 additions & 13 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -96,21 +96,26 @@ def apply_fp8_linear(
     # If dynamic, layer.input_scale is None and x_scale computed from x.
     # If static, layer.input_scale is scalar and x_scale is input_scale.
 
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
     # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             scale_ub=input_scale_ub,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         # Fused GEMM_DQ
-        return ops.cutlass_scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
+        output = ops.cutlass_scaled_mm(qinput,
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale,
+                                       bias=bias)
+        return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
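
For context on the cutlass branch above: ops.scaled_fp8_quant quantizes the flattened input either with a static scalar scale (layer.input_scale) or, in the dynamic case, with one scale per row of input_2d. The sketch below is only a rough approximation of the dynamic per-token behaviour, not the vLLM kernel itself, and it assumes a PyTorch build that provides torch.float8_e4m3fn:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def dynamic_per_token_quant(x_2d: torch.Tensor):
    # One scale per token (row) of the flattened 2D input.
    amax = x_2d.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    x_scale = amax.float() / FP8_MAX  # shape [num_tokens, 1]
    qx = (x_2d / x_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return qx, x_scale

qx, sx = dynamic_per_token_quant(torch.randn(4, 16))

cutlass_scaled_mm then consumes the quantized input together with x_scale (per tensor or per token) and weight_scale (per tensor or per channel), fusing the dequantization into the GEMM, which is why the branch can return directly after reshaping.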
@@ -119,7 +124,7 @@ def apply_fp8_linear(
         # for matrices with batch dimension > 16.
         # This could change in the future.
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
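
This fallback path pads the quantized input to at least 17 rows (num_token_padding=17); the comment above ties this to how torch._scaled_mm behaves for matrices with batch dimension > 16, and the later hunks slice the padding off again with torch.narrow. A small sketch of the pad-then-unpad idea, with an ordinary matmul standing in for torch._scaled_mm:

import torch

def pad_rows(x_2d: torch.Tensor, min_rows: int = 17) -> torch.Tensor:
    # Zero-pad the token (row) dimension up to min_rows, like num_token_padding=17.
    pad = max(0, min_rows - x_2d.shape[0])
    return torch.nn.functional.pad(x_2d, (0, 0, 0, pad)) if pad else x_2d

x = torch.randn(3, 16)
out = pad_rows(x) @ torch.randn(16, 32)    # the GEMM sees 17 rows
out = torch.narrow(out, 0, 0, x.shape[0])  # unpad back to the real 3 tokens
assert out.shape == (3, 32)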
@@ -138,8 +143,10 @@ def apply_fp8_linear(
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+
+            return torch.narrow(output, 0, 0,
+                                input_2d.shape[0]).view(*output_shape)
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
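
The tuple check in the hunk above exists because, per the in-code comment, torch._scaled_mm returns a two-element tuple on torch < 2.5 and a single tensor on torch >= 2.5. A tiny helper expressing the same normalization (hypothetical name, not part of the diff):

def unwrap_scaled_mm_output(output):
    # torch < 2.5 returns a 2-tuple whose first element is the GEMM result;
    # torch >= 2.5 returns the tensor directly.
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    return output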
@@ -176,15 +183,15 @@ def apply_fp8_linear(
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
             # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
 
             # DQ
             # C = sw * sx * (X * W) + bias
             output = output * x_scale * weight_scale.t()
             if bias is not None:
                 output = output + bias
-            return output.to(dtype=input.dtype)
+            return output.to(dtype=input.dtype).view(*output_shape)
 
 
 def apply_int8_linear(
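
In the channelwise fallback, dequantization is unfused: the GEMM runs on the quantized operands and the scales are applied afterwards as C = sw * sx * (X * W) + bias. A shape-only sketch of that step, assuming (as the weight_scale.t() in the code suggests) a per-channel weight scale stored as [out_features, 1] and a per-token activation scale of shape [num_tokens, 1]:

import torch

num_tokens, out_features = 4, 32
raw = torch.randn(num_tokens, out_features)  # stand-in for the unscaled GEMM result X * W
x_scale = torch.rand(num_tokens, 1)          # sx: per-token activation scale
weight_scale = torch.rand(out_features, 1)   # sw: per-output-channel weight scale
bias = torch.randn(out_features)

# C = sw * sx * (X * W) + bias, broadcasting to [num_tokens, out_features]
output = raw * x_scale * weight_scale.t() + bias
assert output.shape == (num_tokens, out_features)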
