Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Float8 Weight Only and FP8 weight + dynamic activation #740

Merged
merged 2 commits into from
Aug 30, 2024

Conversation

drisspg
Copy link
Contributor

@drisspg drisspg commented Aug 23, 2024

Summary

This PR makes some tweaks to the existing FP8 weight only work flow and adds float8_dynamic_activation_float8_weight to the quantize_ API.

New Apis made public:

from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
def float8_dynamic_activation_float8_weight(
    target_dtype: torch.dtype = torch.float8_e4m3fn,
    activation_dtype: torch.dtype = torch.float8_e4m3fn,
    mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True)
):
    """
    Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers.
    Args:
        target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
        activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
        mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
    """
def float8_weight_only(target_dtype: torch.dtype = torch.float8_e4m3fn):
    """
    Applies float8 weight-only symmetric per-channel quantization to linear layers.
    
    Args:
        target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
    Note:
        The actual matmul will be computed in original precision of the weight tensor.
    """

TODO

  • Add Rowwise scaling option to top level api
  • Next would be add float8_static_activation and hooking into calibration flow
  • Adding these two APIs to autoquant
  • Proper way of ensuring hardware support is met when quantizing for these types

Changes

  • Adds Float8Layout, the main reason for doing this is the ability to pass the mm_config down to matmul via the layout_type. This also adds the ability to do fake "transpose". This may still be needed but I have since learned that we register dispatch entires to nn.fucntional.linear. Thus the weight will not be ran through transpose.
  • Wires up the TensorWise Scaling case to _scaled_mm
  • Updates the Quant_primitives to handle the Symmetric no zero-point float to float dequant routine.
  • Updates the AQTLayout's Typing to reflect the changes to dequant.
  • I added various type hints and docs to function that I found helpful
  • We were uncondtionally registering safe globals, gated that by behind the 2.5 flag
  • Option for inverse scaling

Memory snapshot compare

With no_grad():

Screenshot 2024-08-28 at 11 46 16 PM

Normal conditions:
Screenshot 2024-08-28 at 11 47 04 PM

FIX

@jerryzh168 If I decorate quantize affine:

@torch.no_grad()
def quantize_affine(
    input: torch.Tensor,
    block_size: Tuple[int, ...],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    output_dtype: torch.dtype,
    quant_min: Optional[Union[int, float]] = None,
    quant_max: Optional[Union[int, float]] = None,
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
    """

And

@torch.no_grad()
def choose_qparams_affine(
   input: torch.Tensor,
   mapping_type: MappingType,
   block_size: Tuple[int, ...],
   target_dtype: torch.dtype,
   quant_min: Optional[Union[int, float]] = None,
   quant_max: Optional[Union[int, float]] = None,
   eps: Optional[float] = None,
   scale_dtype: Optional[torch.dtype] = None,
   zero_point_dtype: Optional[torch.dtype] = None,
   preserve_zero: bool = True,
   zero_point_domain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:

I get the proper memory usage:
Screenshot 2024-08-28 at 11 57 28 PM

Is it okay If I land these as well or is there some other use case I am missing? I imagine all other input_tenosrs have been ints and havent propgated grads before

Perf micro benchmarks

Runner Script:
https://gist.github.com/drisspg/3baf802f8c8631df2a549840f39e7e0d

Trace: internal_link

Max Autotune
Needs this fix: pytorch/pytorch#134765
Trace: internal_link

Copy link

pytorch-bot bot commented Aug 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/740

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c771c64 with merge base 05224a9 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 23, 2024
@drisspg drisspg force-pushed the working-changes branch 23 times, most recently from 287fc8f to 6e6462f Compare August 28, 2024 22:41
@drisspg drisspg changed the title Hack up AQT + FP8 Add Float8 Weight Only and FP8 weight + dynamic activation Aug 28, 2024
@drisspg drisspg force-pushed the working-changes branch 3 times, most recently from e0b3d31 to 7244571 Compare August 28, 2024 23:33
Copy link
Contributor

@jerryzh168 jerryzh168 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good overall, just had some nit comments inline

@drisspg drisspg merged commit d0e6246 into main Aug 30, 2024
16 checks passed
@drisspg drisspg deleted the working-changes branch August 30, 2024 01:39
@@ -30,14 +30,19 @@ def addmm_float8_unwrapped(
output_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_fast_accum: bool = False,
inverse_scale: bool = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean AffineQuantizedTensor uses scale in the same way Float8Tensor uses 1 / scale, or something else?

@@ -57,6 +59,7 @@
from .utils import _get_per_token_block_size
import logging
from .autoquant import autoquant, AutoQuantizableLinearWeight
from torchao.float8.float8_tensor import ScaledMMConfig
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like this PR makes ScaledMMConfig a public API. Is this intended? If yes, can we move this object to float8/__init__.py, maybe make a dataclass, ensure it's documented, etc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah Ill put up a PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants