Update base for Update on "Add generic fake quantized linear for QAT"
**Summary:** This commit adds a generic fake quantized linear module
to replace the existing, more specific QAT linear classes.
For example, `Int8DynActInt4WeightQATLinear` can be expressed
as follows:

```python
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(
    bit_width=8,
    granularity="per_token",
    symmetric=False,
    dynamic=True,
)
weight_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=8,
    symmetric=True,
    dynamic=True,
)
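# Combine the two configs in a single module; the positional arguments
# are presumably (in_features, out_features, bias), mirroring nn.Linear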
fq_linear = FakeQuantizedLinear(
    16, 32, False, activation_config, weight_config,
)
```

The main motivation is to provide a more flexible way to perform
QAT on models with linear layers. Previously, we had to create a
new linear class every time we wished to experiment with different
fake quantization settings, e.g. a different group size or a
different bit width. Now we can express all of these variants
through a single linear module.
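
For instance, trying a different group size becomes a pure config change (a hypothetical variant, assuming the same `FakeQuantizeConfig` and `FakeQuantizedLinear` signatures as in the example above):

```python
# Hypothetical variant: experiment with a larger group size by swapping
# in a new weight config; no new linear class is needed.
weight_config_g32 = FakeQuantizeConfig(
    bit_width=4,
    group_size=32,
    symmetric=True,
    dynamic=True,
)
fq_linear_g32 = FakeQuantizedLinear(
    16, 32, False, activation_config, weight_config_g32,
)
```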

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]
andrewor14 committed Oct 8, 2024
1 parent 75fcd21 commit d671826
Showing 12 changed files with 27 additions and 16 deletions.
test/dtypes/test_affine_quantized_float.py (6 additions, 2 deletions)

```diff
@@ -26,11 +26,15 @@
     float8_weight_only,
     quantize_,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao.quantization.quant_api import (
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    MappingType,
+    PerRow,
+    PerTensor,
+)
 
 random.seed(0)
 torch.manual_seed(0)
```
test/quantization/test_observer.py (2 additions, 2 deletions)

```diff
@@ -11,14 +11,14 @@
 
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
-    PerAxis,
-    PerTensor,
 )
 from torchao.quantization.quant_api import (
     insert_observers_,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerAxis,
+    PerTensor,
 )
 
 
```
torchao/_models/llama/eval.py (2 additions, 2 deletions)

```diff
@@ -24,9 +24,9 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
+from torchao.quantization.quant_primitives import PerRow, PerTensor
 
 from tokenizer import get_tokenizer
 import time
@@ -255,4 +255,4 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
-    )
+    )
```
torchao/_models/llama/generate.py (1 addition, 1 deletion)

```diff
@@ -216,7 +216,7 @@ def main(
             float8_weight_only,
             float8_dynamic_activation_float8_weight,
         )
-        from torchao.quantization.observer import PerTensor, PerRow
+        from torchao.quantization.quant_primitives import PerTensor, PerRow
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
```
torchao/prototype/awq/api.py (1 addition, 1 deletion)

```diff
@@ -3,11 +3,11 @@
 
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerGroup,
     ZeroPointDomain,
     _DTYPE_TO_QVALUE_BOUNDS,
 )
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.observer import PerGroup
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
 from torchao.dtypes import(
```
torchao/quantization/README.md (1 addition, 1 deletion)

````diff
@@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-from torchao.quantization.observer import PerTensor
+from torchao.quantization.quant_api import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 ```
````
torchao/quantization/autoquant.py (3 additions, 1 deletion)

```diff
@@ -13,11 +13,13 @@
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
+    PerAxis,
+    PerRow,
+    PerTensor,
     safe_int_mm,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
-from torchao.quantization.observer import PerAxis, PerTensor, PerRow
 from torchao.float8.inference import Float8MMConfig
 
 import torch.nn.functional as F
```
torchao/quantization/observer.py (3 additions, 0 deletions)

```diff
@@ -4,6 +4,9 @@
     choose_qparams_affine_with_min_max,
     MappingType,
     Granularity,
+    PerAxis,
+    PerRow,
+    PerTensor,
     ZeroPointDomain,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
```
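
Since `observer.py` now re-imports the granularity types from `quant_primitives`, existing code that imports them from the observer module should keep working. A minimal sanity check (a sketch, assuming the re-export is not hidden by an `__all__` list, which this diff does not show):

```python
# Both import paths should resolve to the same class after this change,
# because observer.py re-exports the granularity types from quant_primitives.
from torchao.quantization.observer import PerTensor as ObserverPerTensor
from torchao.quantization.quant_primitives import PerTensor as PrimitivesPerTensor

assert ObserverPerTensor is PrimitivesPerTensor
```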
torchao/quantization/quant_api.py (3 additions, 1 deletion)

```diff
@@ -54,6 +54,8 @@
 
 from .quant_primitives import (
     MappingType,
+    PerRow,
+    PerTensor,
     ZeroPointDomain,
 )
 from .weight_only import WeightOnlyInt8QuantLinear
@@ -71,7 +73,7 @@
 )
 from torchao.float8.inference import Float8MMConfig
 
-from torchao.quantization.observer import PerTensor, PerRow, get_block_size
+from torchao.quantization.observer import get_block_size
 
 logger = logging.getLogger(__name__)
```
tutorials/calibration_flow/awq_like.py (2 additions, 2 deletions)

```diff
@@ -22,11 +22,11 @@
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
-    PerTensor,
-    PerAxis,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerTensor,
+    PerAxis,
     FP8_TYPES,
 )
```
tutorials/calibration_flow/gptq_like.py (1 addition, 1 deletion)

```diff
@@ -40,10 +40,10 @@
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
-    PerTensor,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerTensor,
     fake_quantize_affine,
 )
```
tutorials/calibration_flow/static_quant.py (2 additions, 2 deletions)

```diff
@@ -17,11 +17,11 @@
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
-    PerTensor,
-    PerAxis,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerTensor,
+    PerAxis,
     FP8_TYPES,
 )
```
