From 2669eb2db43362eb562ede2fdace49ef119df6e5 Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Thu, 27 Jun 2024 17:44:26 +0000
Subject: [PATCH 1/3] enable loading fp8 phi

---
 vllm/model_executor/layers/linear.py          | 14 ++++++-
 .../model_executor/layers/quantization/fp8.py | 38 +++++++++----------
 2 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 45f805547b414..90bc826cd4be8 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -387,8 +387,13 @@ def weight_loader(self,
                                            None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
@@ -571,8 +576,13 @@ def weight_loader(self,
                                            None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index bbf3cde54782d..229d85398c44a 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        self.qkv_mlp_fused_on_disk = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -111,6 +112,7 @@ def _create_scale_param(
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
@@ -170,10 +172,13 @@ def create_weights(
 
     def scales_shard_indexer(
             self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+            shard_id: Optional[Union[str, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}
 
-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.qkv_mlp_fused_on_disk = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
@@ -205,15 +210,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # WEIGHT_SCALE / WEIGHT
             #   Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+
+            if not self.qkv_mlp_fused_on_disk:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
+                                                      layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
 
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
             # WEIGHT
@@ -227,10 +234,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
@@ -317,11 +320,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.kv_scale
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)

From 06fcd94cd215094d5bdea045e87293968ddf6f93 Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Thu, 27 Jun 2024 17:55:09 +0000
Subject: [PATCH 2/3] formatted

---
 vllm/model_executor/layers/linear.py           | 2 +-
 vllm/model_executor/layers/quantization/fp8.py | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 90bc826cd4be8..8a574ad5468e1 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -582,7 +582,7 @@ def weight_loader(self,
                 if fp8_scales_shard_indexer is not None:
                     param_data, loaded_weight = fp8_scales_shard_indexer(
                         param_data, loaded_weight, loaded_shard_id)
-                
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 229d85398c44a..24a300b0fa8e4 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -171,8 +171,9 @@ def create_weights(
             **extra_weight_attrs)
 
     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Optional[Union[str, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Optional[Union[str,
+                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}
 
         if shard_id is None:
@@ -215,8 +216,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 start = 0
                 for idx, logical_width in enumerate(layer.logical_widths):
                     end = start + logical_width
-                    weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                      layer.weight_scale[idx])
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
 
                     layer.weight[start:end, :] = per_tensor_quantize(
                         weight_dq, layer.weight_scale.max())
                     start = end

From 25a88bce52b82788d7fd9edef7b5320579b2966b Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Thu, 27 Jun 2024 19:13:33 +0000
Subject: [PATCH 3/3] address mgoin comment

---
 vllm/model_executor/layers/quantization/fp8.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 24a300b0fa8e4..1c760566c28d7 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -98,7 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
-        self.qkv_mlp_fused_on_disk = False
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -178,7 +178,7 @@ def scales_shard_indexer(
 
         if shard_id is None:
             shard_id = 0
-            self.qkv_mlp_fused_on_disk = True
+            self.fused_module_in_checkpoint = True
         elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
@@ -212,7 +212,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
 
-            if not self.qkv_mlp_fused_on_disk:
+            if not self.fused_module_in_checkpoint:
                 start = 0
                 for idx, logical_width in enumerate(layer.logical_widths):
                     end = start + logical_width
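
For context on the loop that patch 1 now guards with the fused-checkpoint flag: a fused layer (qkv_proj, gate_up_proj) carries one fp8 weight per logical shard, each quantized with its own per-tensor scale, and process_weights_after_loading collapses them onto a single max scale. Below is a minimal standalone sketch of that requantization, not the vLLM implementation itself; the function bodies and the helper name requantize_to_max_scale are illustrative assumptions.

# Standalone sketch: per-tensor fp8 requantization to a single max scale.
# Assumptions: symmetric per-tensor quantization with dequant = fp8_value * scale;
# `requantize_to_max_scale` is a hypothetical helper, not a vLLM API.
from typing import List, Tuple

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MIN = torch.finfo(FP8_DTYPE).min
FP8_MAX = torch.finfo(FP8_DTYPE).max


def per_tensor_quantize(tensor: torch.Tensor,
                        scale: torch.Tensor) -> torch.Tensor:
    # Scale down, clamp to the fp8 representable range, then cast to fp8.
    return (tensor / scale).clamp(min=FP8_MIN, max=FP8_MAX).to(FP8_DTYPE)


def per_tensor_dequantize(tensor: torch.Tensor,
                          scale: torch.Tensor) -> torch.Tensor:
    # fp8 tensors do not support arithmetic directly; widen before scaling.
    return tensor.to(torch.float32) * scale


def requantize_to_max_scale(weight: torch.Tensor, weight_scale: torch.Tensor,
                            logical_widths: List[int]
                            ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Each logical shard (q/k/v or gate/up) was quantized with its own scale;
    # round-trip every shard through the single max scale so one scale can
    # serve the whole fused weight.
    max_w_scale = weight_scale.max()
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_dq = per_tensor_dequantize(weight[start:end, :],
                                          weight_scale[idx])
        weight[start:end, :] = per_tensor_quantize(weight_dq, max_w_scale)
        start = end
    return weight, max_w_scale

When the checkpoint already stores the fused module quantized with one scale, every logical shard shares that scale, so this dequantize/requantize round-trip would be a no-op; that is why the loop is skipped when fused_module_in_checkpoint is set.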
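
Patch 1 also pre-fills each scale parameter with the smallest fp8 value and routes shard_id=None to slot 0. The following sketch shows why that combination recovers the right scale for a checkpoint that stores a single fused scale; the three-shard layout and the value 0.025 are made up for illustration.

# Sketch of loading one fused scale into a per-shard scale parameter.
# Assumption: a qkv layer with three logical shards and a checkpoint that
# stores a single weight_scale for the fused module.
import torch

num_logical_shards = 3

# _create_scale_param pre-fills with the smallest fp8 value so that .max()
# later ignores slots that never receive a value from the checkpoint.
weight_scale = torch.full((num_logical_shards, ),
                          torch.finfo(torch.float8_e4m3fn).min,
                          dtype=torch.float32)

loaded_scale = torch.tensor(0.025)  # the single fused scale on disk
shard_id = None                     # fused checkpoints carry no per-shard id

# scales_shard_indexer maps shard_id=None to slot 0 and marks the layer as
# fused_module_in_checkpoint, so the requantization loop is skipped.
idx = 0 if shard_id is None else shard_id
weight_scale[idx] = loaded_scale

# .max() recovers the fused scale regardless of the untouched slots.
assert weight_scale.max() == loaded_scale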