From 82726017e15373fe24b1c3185500d911d1459b1c Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Tue, 8 Oct 2024 13:01:26 -0700
Subject: [PATCH] Move and rename GranularityType -> Granularity

Summary: Move GranularityType to quant_primitives.py to be consistent
with other similar fields like MappingType and ZeroPointDomain.

Test Plan: CI

ghstack-source-id: 3525830c9b2ef33fd5fa22b93a1ace37f40971f9
Pull Request resolved: https://github.com/pytorch/ao/pull/1038
---
 test/dtypes/test_affine_quantized_float.py |  8 +-
 test/quantization/test_observer.py         |  4 +-
 torchao/_models/llama/eval.py              |  4 +-
 torchao/_models/llama/generate.py          |  2 +-
 torchao/prototype/awq/api.py               |  2 +-
 torchao/prototype/awq/core.py              |  9 +-
 torchao/quantization/README.md             |  2 +-
 torchao/quantization/autoquant.py          |  4 +-
 torchao/quantization/observer.py           | 99 ++++------------------
 torchao/quantization/quant_api.py          |  4 +-
 torchao/quantization/quant_primitives.py   | 69 +++++++++++++++
 tutorials/calibration_flow/awq_like.py     |  4 +-
 tutorials/calibration_flow/gptq_like.py    |  2 +-
 tutorials/calibration_flow/static_quant.py |  4 +-
 14 files changed, 115 insertions(+), 102 deletions(-)

diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 621e3596e..34f6bd2f1 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -26,11 +26,15 @@
     float8_weight_only,
     quantize_,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao.quantization.quant_api import (
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    PerRow,
+    PerTensor,
+    choose_qparams_affine,
+)
 
 random.seed(0)
 torch.manual_seed(0)
diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py
index 8c8007871..3cca97f07 100644
--- a/test/quantization/test_observer.py
+++ b/test/quantization/test_observer.py
@@ -11,14 +11,14 @@
 
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
-    PerAxis,
-    PerTensor,
 )
 from torchao.quantization.quant_api import (
     insert_observers_,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerAxis,
+    PerTensor,
 )
 
 
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 6d46e4587..1655e450d 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -24,9 +24,9 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
+from torchao.quantization.quant_primitives import PerRow, PerTensor
 from tokenizer import get_tokenizer
 import time
 
@@ -255,4 +255,4 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
-    )
\ No newline at end of file
+    )
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 270054e13..1971ec094 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -216,7 +216,7 @@ def main(
             float8_weight_only,
             float8_dynamic_activation_float8_weight,
         )
-        from torchao.quantization.observer import PerTensor, PerRow
+        from torchao.quantization.quant_primitives import PerTensor, PerRow
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
"int8dq" in quantization: diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index e3a8827e2..6827fe391 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -3,11 +3,11 @@ from torchao.quantization.quant_primitives import ( MappingType, + PerGroup, ZeroPointDomain, _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata -from torchao.quantization.observer import PerGroup from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType from torchao.dtypes import( diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 77810a2e4..15b0ec6c1 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -9,10 +9,11 @@ from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.quant_primitives import ( MappingType, + Granularity, ZeroPointDomain, ) from torchao.quantization.observer import ( - AffineQuantizedObserverBase, GranularityType + AffineQuantizedObserverBase, ) @@ -20,7 +21,7 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, weight: torch.Tensor, bias: torch.Tensor, - quantization_granularity: GranularityType, + quantization_granularity: Granularity, mapping_type: MappingType, target_dtype: torch.dtype, n_validation_examples: int, @@ -40,7 +41,7 @@ def __init__(self, Args: weight: The weight tensor to be observed. bias: The bias tensor to be observed. - quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point + quantization_granularity: Granularity which specifies how many weights share the same scale/zero point input_dtype: The data type of the input tensor. 
             mapping_type: Always set to asymmetric
             target_dtype: The target data type of the quantized tensor
@@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
         observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
         observed_linear.weight = float_linear.weight
         observed_linear.bias = float_linear.bias
-        return observed_linear
\ No newline at end of file
+        return observed_linear
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index c936b7ef8..9d7b04947 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
 
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-from torchao.quantization.observer import PerTensor
+from torchao.quantization.quant_api import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 ```
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index a5568c4e1..8a02cccf2 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -13,11 +13,13 @@
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
+    PerAxis,
+    PerRow,
+    PerTensor,
     safe_int_mm,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
-from torchao.quantization.observer import PerAxis, PerTensor, PerRow
 from torchao.float8.inference import Float8MMConfig
 import torch.nn.functional as F
 
diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py
index bef4abe71..d702b54f5 100644
--- a/torchao/quantization/observer.py
+++ b/torchao/quantization/observer.py
@@ -3,12 +3,15 @@
     _get_reduction_params,
     choose_qparams_affine_with_min_max,
     MappingType,
+    Granularity,
+    PerAxis,
+    PerRow,
+    PerTensor,
     ZeroPointDomain,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass
 from typing import Tuple, Optional, Any
 from functools import partial
 import logging
@@ -16,74 +19,6 @@
 
 logger = logging.getLogger(__name__)
 
-@dataclass(frozen=True)
-class GranularityType:
-    """
-    Base class for representing the granularity of quantization.
-
-    This class serves as a parent for specific granularity types used in
-    quantization operations, such as per-tensor or per-axis quantization.
-    """
-    pass
-
-@dataclass(frozen=True)
-class PerTensor(GranularityType):
-    """
-    Represents per-tensor granularity in quantization.
-
-    This granularity type calcualtes the quantization parameters
-    based off the entire tensor.
-    """
-    pass
-
-@dataclass(frozen=True)
-class PerAxis(GranularityType):
-    """
-    Represents per-axis granularity in quantization.
-
-    This granularity type calcualtes different quantization parameters
-    along a specified axis of the tensor.
-
-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
-
-
-    Attributes:
-        axis (int): The axis along which reduction is performed.
- """ - axis: int - -@dataclass(frozen=True) - -class PerGroup(GranularityType): - """ - Represents per-channel group granularity in quantization. - - This granularity type calcualtes different quantization parameters - for each group of elements. - - For example if the input tensor is shape [8, 16], and the group size is 4, then - the input tensor is reshaped to [64, 4] - quantization parameters are calculated for each group of 4 elements, - giving a total of 64 quantization parameters. - - Attributes: - group_size (int): The size of each quantization group - - """ - group_size: int - -class PerRow(GranularityType): - """ - Represents row-wise granularity in quantization. - - This is a special case of per-axis quantization and is unique to Float8 matmuls - where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight - is quantized with a block_size of (1, weight.shape[1]). - """ - pass - # borrowed from torch.ao.quantization.observer class _PartialWrapper: def __init__(self, p): @@ -120,23 +55,23 @@ def _with_args(cls_or_self, *args, **kwargs): def get_block_size( - input_shape: Tuple[int, ...], granularity_type: GranularityType + input_shape: Tuple[int, ...], granularity: Granularity ) -> Tuple[int, ...]: """Get the block size based on the input shape and granularity type. Args: input_shape: The input tensor shape possibly more than 2 dimensions - granularity_type: The granularity type of the quantization + granularity: The granularity type of the quantization """ - if isinstance(granularity_type, PerTensor): + if isinstance(granularity, PerTensor): return input_shape - elif isinstance(granularity_type, PerAxis): + elif isinstance(granularity, PerAxis): block_size = list(input_shape) - block_size[granularity_type.axis] = 1 + block_size[granularity.axis] = 1 return tuple(block_size) - elif isinstance(granularity_type, PerRow): + elif isinstance(granularity, PerRow): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) - raise ValueError(f"Unsupported GranularityType: {granularity_type}") + raise ValueError(f"Unsupported Granularity: {granularity}") ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: @@ -146,7 +81,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module): """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) Args: - `granularity_type` and `block_size`: The granularity of the quantization, + `granularity` and `block_size`: The granularity of the quantization, must specify at least one, if both are specified `block_size` takes precedence Current supported granularity type are `PerTensor` and `PerAxis` other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` @@ -158,7 +93,7 @@ def __init__( self, mapping_type: MappingType, target_dtype: torch.dtype, - granularity_type: GranularityType, + granularity: Granularity, quant_min: Optional[int] = None, quant_max: Optional[int] = None, eps: Optional[float] = None, @@ -168,11 +103,11 @@ def __init__( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, ): super().__init__() - assert granularity_type is not None, "granularity_type is None" + assert granularity is not None, "granularity is None" self.mapping_type = mapping_type self.target_dtype = target_dtype - self.granularity_type = granularity_type + self.granularity = granularity self.quant_min = quant_min self.quant_max = quant_max self.eps = eps @@ -202,8 +137,8 @@ def forward(self, input: torch.Tensor): return input 
         input_detached = input.detach()
-        assert self.granularity_type is not None, "granularity_type is None"
-        block_size = get_block_size(input_detached.shape, self.granularity_type)
+        assert self.granularity is not None, "granularity is None"
+        block_size = get_block_size(input_detached.shape, self.granularity)
         shape_for_reduction, reduction_dims = _get_reduction_params(
             block_size, input_detached.size()
 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 6c4142506..758437468 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -54,6 +54,8 @@
 
 from .quant_primitives import (
     MappingType,
+    PerRow,
+    PerTensor,
     ZeroPointDomain,
 )
 from .weight_only import WeightOnlyInt8QuantLinear
@@ -71,7 +73,7 @@
 )
 from torchao.float8.inference import Float8MMConfig
 
-from torchao.quantization.observer import PerTensor, PerRow, get_block_size
+from torchao.quantization.observer import get_block_size
 
 logger = logging.getLogger(__name__)
 
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index b1561e4cf..594bf8c42 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict, Callable, Union
 import torch, math
@@ -64,6 +65,74 @@ class ZeroPointDomain(Enum):
     INT = auto()
     FLOAT = auto()
 
+@dataclass(frozen=True)
+class Granularity:
+    """
+    Base class for representing the granularity of quantization.
+
+    This class serves as a parent for specific granularity types used in
+    quantization operations, such as per-tensor or per-axis quantization.
+    """
+    pass
+
+@dataclass(frozen=True)
+class PerTensor(Granularity):
+    """
+    Represents per-tensor granularity in quantization.
+
+    This granularity type calculates the quantization parameters
+    based on the entire tensor.
+    """
+    pass
+
+@dataclass(frozen=True)
+class PerAxis(Granularity):
+    """
+    Represents per-axis granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    along a specified axis of the tensor.
+
+    For example, if the input tensor has shape [8, 16] and axis=0, then
+    the quantization parameters are calculated for each row of the tensor,
+    giving a total of 8 quantization parameters.
+
+    Attributes:
+        axis (int): The axis along which reduction is performed.
+    """
+    axis: int
+
+
+@dataclass(frozen=True)
+class PerGroup(Granularity):
+    """
+    Represents per-channel group granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    for each group of elements.
+
+    For example, if the input tensor has shape [8, 16] and the group size is 4, then
+    the input tensor is reshaped to [64, 4], and
+    quantization parameters are calculated for each group of 4 elements,
+    giving a total of 64 quantization parameters.
+
+    Attributes:
+        group_size (int): The size of each quantization group.
+
+    """
+    group_size: int
+
+@dataclass(frozen=True)
+class PerRow(Granularity):
+    """
+    Represents row-wise granularity in quantization.
+
+    This is a special case of per-axis quantization and is unique to Float8 matmuls,
+    where the input is quantized with a block_size of (1, ..., input.shape[-1]) and the weight
+    is quantized with a block_size of (1, weight.shape[1]).
+ """ + pass + if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index 037dbae0f..41a43bda5 100644 --- a/tutorials/calibration_flow/awq_like.py +++ b/tutorials/calibration_flow/awq_like.py @@ -22,11 +22,11 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, - PerAxis, ) from torchao.quantization.quant_primitives import ( MappingType, + PerTensor, + PerAxis, FP8_TYPES, ) diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index edb1b257e..07dd2876a 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -40,10 +40,10 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, ) from torchao.quantization.quant_primitives import ( MappingType, + PerTensor, fake_quantize_affine, ) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index f75485d3d..d5469d432 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -17,11 +17,11 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, - PerAxis, ) from torchao.quantization.quant_primitives import ( MappingType, + PerTensor, + PerAxis, FP8_TYPES, )