Update on "Add generic fake quantized linear for QAT"
**Summary:** This commit adds a generic fake quantized linear module
to replace the existing, more specific QAT linear classes.
For example, `Int8DynActInt4WeightQATLinear` can be expressed
as follows:

```python
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(
    bit_width=8,
    granularity="per_token",
    symmetric=False,
    dynamic=True,
)
weight_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=8,
    symmetric=True,
    dynamic=True,
)
fq_linear = FakeQuantizedLinear(
    16, 32, False, activation_config, weight_config,
)
```
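As a usage sketch continuing the example above (the input shapes here are illustrative, and this assumes `FakeQuantizedLinear` follows the standard `nn.Linear` forward convention):

```python
import torch

# Training proceeds in floating point; the activations and weights are
# fake quantized on the fly during the forward pass.
x = torch.randn(3, 16)
y = fq_linear(x)  # fake quantized linear output of shape (3, 32)
```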

The main motivation is to provide a more flexible way to perform
QAT on models with linear layers. Previously, we had to create a
new linear class every time we wished to experiment with different
fake quantization settings, e.g. a different group size or bit
width. Now we can express all of these variants with a single
linear module, as sketched below.
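For instance, a variant with a larger weight group size needs only a new config (a minimal sketch reusing the names from the example above; the specific values are illustrative, not from the original commit):

```python
# Same module class, different quantization settings:
# no new linear subclass required.
weight_config_g16 = FakeQuantizeConfig(
    bit_width=4,
    group_size=16,
    symmetric=True,
    dynamic=True,
)
fq_linear_g16 = FakeQuantizedLinear(
    16, 32, False, activation_config, weight_config_g16,
)
```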

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]
andrewor14 committed Oct 8, 2024
commit ab43744 (2 parents: 8e5d2ea + dbad878)
Showing 1 changed file with 8 additions and 8 deletions.
```diff
--- a/test/quantization/test_observer.py
+++ b/test/quantization/test_observer.py
@@ -42,7 +42,7 @@ def test_min_max_per_tensor_affine(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.ASYMMETRIC,
             torch.uint8,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -54,7 +54,7 @@ def test_min_max_per_channel_affine(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.ASYMMETRIC,
             torch.uint8,
-            granularity_type=PerAxis(axis=0),
+            granularity=PerAxis(axis=0),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -68,7 +68,7 @@ def test_block_size_calc_success(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -87,7 +87,7 @@ def test_block_size_calc_success(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerAxis(1),
+            granularity=PerAxis(1),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -102,7 +102,7 @@ def test_block_size_row_errors(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerAxis(0),
+            granularity=PerAxis(0),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -121,7 +121,7 @@ def test_block_size_row_errors(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerAxis(1),
+            granularity=PerAxis(1),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -149,7 +149,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
         input_observer = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -159,7 +159,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
         weight_observer = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
```
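For reference, a minimal sketch of constructing the observer with the renamed keyword (the import paths follow the test file above and are assumed to match this version of torchao):

```python
import torch

from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType

# `granularity_type=` has been renamed to `granularity=`.
obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
)
```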
