Move and rename GranularityType -> Granularity
Summary: Move GranularityType to quant_primitives.py to be
consistent with other similar fields like MappingType and
ZeroPointDomain.

Test Plan: CI

ghstack-source-id: 3525830c9b2ef33fd5fa22b93a1ace37f40971f9
Pull Request resolved: #1038
andrewor14 committed Oct 8, 2024
1 parent 85c7e9a commit 8272601
Showing 14 changed files with 115 additions and 102 deletions.
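
In effect, downstream code now imports the granularity classes from `torchao.quantization.quant_primitives` rather than `torchao.quantization.observer`. A minimal sketch of the new import surface (illustrative only; the `describe` helper is hypothetical and not part of the commit):

```python
# Before: from torchao.quantization.observer import GranularityType, PerTensor
# After:  Granularity sits next to MappingType and ZeroPointDomain.
from torchao.quantization.quant_primitives import (
    Granularity,  # renamed from GranularityType
    MappingType,
    PerTensor,
    ZeroPointDomain,
)

def describe(granularity: Granularity) -> str:
    # Granularity subclasses are frozen dataclasses; isinstance checks are
    # the idiomatic way to branch on them (see get_block_size below).
    if isinstance(granularity, PerTensor):
        return "one scale/zero_point shared by the whole tensor"
    return f"granularity: {granularity}"
```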
test/dtypes/test_affine_quantized_float.py (6 additions, 2 deletions)

@@ -26,11 +26,15 @@
     float8_weight_only,
     quantize_,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao.quantization.quant_api import (
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    PerRow,
+    PerTensor,
+    choose_qparams_affine,
+)
 
 random.seed(0)
 torch.manual_seed(0)
test/quantization/test_observer.py (2 additions, 2 deletions)

@@ -11,14 +11,14 @@
 
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
-    PerAxis,
-    PerTensor,
 )
 from torchao.quantization.quant_api import (
     insert_observers_,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerAxis,
+    PerTensor,
 )
 
 
torchao/_models/llama/eval.py (2 additions, 2 deletions)

@@ -24,9 +24,9 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
+from torchao.quantization.quant_primitives import PerRow, PerTensor
 
 from tokenizer import get_tokenizer
 import time

@@ -255,4 +255,4 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
-    )
\ No newline at end of file
+    )
torchao/_models/llama/generate.py (1 addition, 1 deletion)

@@ -216,7 +216,7 @@ def main(
         float8_weight_only,
         float8_dynamic_activation_float8_weight,
     )
-    from torchao.quantization.observer import PerTensor, PerRow
+    from torchao.quantization.quant_primitives import PerTensor, PerRow
     if "int8wo" in quantization:
         quantize_(model, int8_weight_only())
     if "int8dq" in quantization:
torchao/prototype/awq/api.py (1 addition, 1 deletion)

@@ -3,11 +3,11 @@
 
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerGroup,
     ZeroPointDomain,
     _DTYPE_TO_QVALUE_BOUNDS,
 )
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.observer import PerGroup
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
 from torchao.dtypes import(
torchao/prototype/awq/core.py (5 additions, 4 deletions)

@@ -9,18 +9,19 @@
 from torchao.dtypes import to_affine_quantized_intx
 from torchao.quantization.quant_primitives import (
     MappingType,
+    Granularity,
     ZeroPointDomain,
 )
 from torchao.quantization.observer import (
-    AffineQuantizedObserverBase, GranularityType
+    AffineQuantizedObserverBase,
 )
 
 
 class AWQObserver(AffineQuantizedObserverBase):
     def __init__(self,
         weight: torch.Tensor,
         bias: torch.Tensor,
-        quantization_granularity: GranularityType,
+        quantization_granularity: Granularity,
         mapping_type: MappingType,
         target_dtype: torch.dtype,
         n_validation_examples: int,

@@ -40,7 +41,7 @@ def __init__(self,
         Args:
             weight: The weight tensor to be observed.
             bias: The bias tensor to be observed.
-            quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point
+            quantization_granularity: Granularity which specifies how many weights share the same scale/zero point
             input_dtype: The data type of the input tensor.
             mapping_type: Always set to asymmetric
             target_dtype: The target data type of the quantized tensor

@@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
         observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
         observed_linear.weight = float_linear.weight
         observed_linear.bias = float_linear.bias
-        return observed_linear
\ No newline at end of file
+        return observed_linear
torchao/quantization/README.md (1 addition, 1 deletion)

@@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-from torchao.quantization.observer import PerTensor
+from torchao.quantization.quant_api import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 ```
 
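
For comparison, per-row (float8 row-wise) granularity follows the same pattern. A hedged variant (assuming `PerRow` is exposed via `quant_primitives` after this commit; `model` is your prepared module):

```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.quant_primitives import PerRow  # assumed import path post-move

quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```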
torchao/quantization/autoquant.py (3 additions, 1 deletion)

@@ -13,11 +13,13 @@
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
+    PerAxis,
+    PerRow,
+    PerTensor,
     safe_int_mm,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
-from torchao.quantization.observer import PerAxis, PerTensor, PerRow
 from torchao.float8.inference import Float8MMConfig
 
 import torch.nn.functional as F
torchao/quantization/observer.py (17 additions, 82 deletions)

@@ -3,87 +3,22 @@
     _get_reduction_params,
     choose_qparams_affine_with_min_max,
     MappingType,
+    Granularity,
+    PerAxis,
+    PerRow,
+    PerTensor,
     ZeroPointDomain,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass
 from typing import Tuple, Optional, Any
 from functools import partial
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-@dataclass(frozen=True)
-class GranularityType:
-    """
-    Base class for representing the granularity of quantization.
-
-    This class serves as a parent for specific granularity types used in
-    quantization operations, such as per-tensor or per-axis quantization.
-    """
-    pass
-
-@dataclass(frozen=True)
-class PerTensor(GranularityType):
-    """
-    Represents per-tensor granularity in quantization.
-
-    This granularity type calcualtes the quantization parameters
-    based off the entire tensor.
-    """
-    pass
-
-@dataclass(frozen=True)
-class PerAxis(GranularityType):
-    """
-    Represents per-axis granularity in quantization.
-
-    This granularity type calcualtes different quantization parameters
-    along a specified axis of the tensor.
-
-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
-
-    Attributes:
-        axis (int): The axis along which reduction is performed.
-    """
-    axis: int
-
-@dataclass(frozen=True)
-class PerGroup(GranularityType):
-    """
-    Represents per-channel group granularity in quantization.
-
-    This granularity type calcualtes different quantization parameters
-    for each group of <group_size> elements.
-
-    For example if the input tensor is shape [8, 16], and the group size is 4, then
-    the input tensor is reshaped to [64, 4]
-    quantization parameters are calculated for each group of 4 elements,
-    giving a total of 64 quantization parameters.
-
-    Attributes:
-        group_size (int): The size of each quantization group
-    """
-    group_size: int
-
-@dataclass(frozen=True)
-class PerRow(GranularityType):
-    """
-    Represents row-wise granularity in quantization.
-
-    This is a special case of per-axis quantization and is unique to Float8 matmuls
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
-    is quantized with a block_size of (1, weight.shape[1]).
-    """
-    pass
-
 # borrowed from torch.ao.quantization.observer
 class _PartialWrapper:
     def __init__(self, p):

@@ -120,23 +55,23 @@ def _with_args(cls_or_self, *args, **kwargs):
 
 
 def get_block_size(
-    input_shape: Tuple[int, ...], granularity_type: GranularityType
+    input_shape: Tuple[int, ...], granularity: Granularity
 ) -> Tuple[int, ...]:
     """Get the block size based on the input shape and granularity type.
 
     Args:
         input_shape: The input tensor shape possibly more than 2 dimensions
-        granularity_type: The granularity type of the quantization
+        granularity: The granularity type of the quantization
     """
-    if isinstance(granularity_type, PerTensor):
+    if isinstance(granularity, PerTensor):
         return input_shape
-    elif isinstance(granularity_type, PerAxis):
+    elif isinstance(granularity, PerAxis):
         block_size = list(input_shape)
-        block_size[granularity_type.axis] = 1
+        block_size[granularity.axis] = 1
         return tuple(block_size)
-    elif isinstance(granularity_type, PerRow):
+    elif isinstance(granularity, PerRow):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
-    raise ValueError(f"Unsupported GranularityType: {granularity_type}")
+    raise ValueError(f"Unsupported Granularity: {granularity}")
 
 
 ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:

@@ -146,7 +81,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
     """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
 
     Args:
-      `granularity_type` and `block_size`: The granularity of the quantization,
+      `granularity` and `block_size`: The granularity of the quantization,
        must specify at least one, if both are specified `block_size` takes precedence
       Current supported granularity type are `PerTensor` and `PerAxis`
      other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`

@@ -158,7 +93,7 @@ def __init__(
         self,
         mapping_type: MappingType,
         target_dtype: torch.dtype,
-        granularity_type: GranularityType,
+        granularity: Granularity,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         eps: Optional[float] = None,

@@ -168,11 +103,11 @@ def __init__(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
     ):
         super().__init__()
-        assert granularity_type is not None, "granularity_type is None"
+        assert granularity is not None, "granularity is None"
 
         self.mapping_type = mapping_type
         self.target_dtype = target_dtype
-        self.granularity_type = granularity_type
+        self.granularity = granularity
         self.quant_min = quant_min
         self.quant_max = quant_max
         self.eps = eps

@@ -202,8 +137,8 @@ def forward(self, input: torch.Tensor):
             return input
 
         input_detached = input.detach()
-        assert self.granularity_type is not None, "granularity_type is None"
-        block_size = get_block_size(input_detached.shape, self.granularity_type)
+        assert self.granularity is not None, "granularity is None"
+        block_size = get_block_size(input_detached.shape, self.granularity)
 
         shape_for_reduction, reduction_dims = _get_reduction_params(
             block_size, input_detached.size()
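
To make the granularity semantics concrete, here is a small worked sketch of `get_block_size` based on the logic above (shapes are example values; import paths assume this commit):

```python
from torchao.quantization.observer import get_block_size
from torchao.quantization.quant_primitives import PerAxis, PerRow, PerTensor

shape = (8, 16)

# PerTensor: one block covering the whole tensor -> a single set of qparams.
assert get_block_size(shape, PerTensor()) == (8, 16)

# PerAxis(axis=0): block size 1 along axis 0 -> one set of qparams per row,
# i.e. 8 sets for an (8, 16) tensor.
assert get_block_size(shape, PerAxis(axis=0)) == (1, 16)

# PerRow: (1, ..., last_dim) -> the float8 row-wise case.
assert get_block_size(shape, PerRow()) == (1, 16)

# PerGroup is not handled by get_block_size here; it splits a tensor into
# contiguous groups of `group_size` elements, so an (8, 16) tensor with
# group_size=4 yields 8 * 16 / 4 = 32 groups, one set of qparams per group.
```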
torchao/quantization/quant_api.py (3 additions, 1 deletion)

@@ -54,6 +54,8 @@
 
 from .quant_primitives import (
     MappingType,
+    PerRow,
+    PerTensor,
     ZeroPointDomain,
 )
 from .weight_only import WeightOnlyInt8QuantLinear

@@ -71,7 +73,7 @@
 )
 from torchao.float8.inference import Float8MMConfig
 
-from torchao.quantization.observer import PerTensor, PerRow, get_block_size
+from torchao.quantization.observer import get_block_size
 
 logger = logging.getLogger(__name__)
 