Update base for Update on "Add generic fake quantized linear for QAT"
**Summary:** This commit adds a generic fake quantized linear module
to replace uses of the existing, more specific QAT linear classes.
For example, `Int8DynActInt4WeightQATLinear` can be expressed
as follows:

```python
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

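# Activations: 8-bit, asymmetric, per-token, dynamically quantized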
activation_config = FakeQuantizeConfig(
    bit_width=8,
    granularity="per_token",
    symmetric=False,
    dynamic=True,
)
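# Weights: 4-bit, symmetric, per-group with group size 8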
weight_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=8,
    symmetric=True,
    dynamic=True,
)
fq_linear = FakeQuantizedLinear(
    16, 32, False, activation_config, weight_config,
)
```
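
A minimal usage sketch (assuming `FakeQuantizedLinear` follows
`nn.Linear`'s forward signature, as the constructor arguments above
suggest):

```python
import torch

# Input's last dimension matches in_features=16; fake quantization is
# applied to activations and weights during the forward pass.
x = torch.randn(3, 16)
y = fq_linear(x)
print(y.shape)  # torch.Size([3, 32])
```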

The main motivation is to provide a more flexible way to perform
QAT on models with linear layers. Previously, we had to create a
new linear class every time we wished to experiment with different
fake quantization settings, e.g. a different group size or bit
width. Now we can express these variations with a single linear
module, as sketched below.
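
For instance, a weight-only 4-bit variant with a larger group size
needs only a new config. This is a sketch: the names
`weight_only_config` and `fq_linear_4w` are illustrative, and it
assumes `activation_config` may be omitted for weight-only fake
quantization (as `test_fake_quantized_linear_4w` suggests):

```python
# Sketch: experiment with a different group size by swapping configs,
# not by writing a new linear class.
weight_only_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=32,
    symmetric=True,
    dynamic=True,
)
# Assumption: activation_config is optional, enabling weight-only QAT.
fq_linear_4w = FakeQuantizedLinear(
    16, 32, False, weight_config=weight_only_config,
)
```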

**Test Plan:**
```
python test/quantization/test_qat.py -k test_fake_quantize_config
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
```

[ghstack-poisoned]
andrewor14 committed Oct 8, 2024
1 parent d671826 commit d4332cb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion in test/dtypes/test_affine_quantized_float.py

```diff
@@ -30,10 +30,10 @@
     float8_static_activation_float8_weight,
 )
 from torchao.quantization.quant_primitives import (
-    choose_qparams_affine,
     MappingType,
     PerRow,
     PerTensor,
+    choose_qparams_affine,
 )
 
 random.seed(0)
```
