Add option to disable fake quant for 8da4w QAT
Summary: This feature helps with model convergence during QAT.
The user can disable observation/fake quant for the first N
steps and re-enable them later, allowing the activation and
weight values to stabilize before applying quantization.

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_disable_fake_quant
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_disable_fake_quant_backward

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
andrewor14 committed May 2, 2024
1 parent ac53d7f commit 56afc27
Showing 2 changed files with 145 additions and 20 deletions.
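
The Summary above describes keeping fake quant off for the first N training steps and turning it back on once activations and weights have stabilized. A minimal sketch of that workflow, using the quantizer and toggle helpers added in this commit, is below; the model, dataloader, optimizer, loss function, and warmup-step count are placeholders, not part of this change.

from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

# Placeholders: any float model with nn.Linear layers, plus a dataloader,
# optimizer, and loss_fn defined elsewhere.
model = MyTransformer()
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
model = quantizer.prepare(model)

# Run the first N steps without fake quant so activation and weight values
# can stabilize before quantization error is introduced.
model.apply(disable_8da4w_fake_quant)
num_warmup_steps = 1000  # illustrative value

for step, (inputs, targets) in enumerate(dataloader):
    if step == num_warmup_steps:
        model.apply(enable_8da4w_fake_quant)  # re-enable for the rest of training
    loss = loss_fn(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()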
98 changes: 98 additions & 0 deletions test/quantization/test_qat.py
@@ -200,6 +200,104 @@ def test_qat_8da4w_quantizer(self):
        for k in ptq_state_dict.keys():
            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
    def test_qat_8da4w_quantizer_disable_fake_quant(self):
        """
        Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
        """
        from torchao.quantization.prototype.qat import (
            Int8DynActInt4WeightQATQuantizer,
            disable_8da4w_fake_quant,
            enable_8da4w_fake_quant,
        )

        group_size = 16
        torch.manual_seed(self.SEED)
        m = M()
        m2 = copy.deepcopy(m)
        m3 = copy.deepcopy(m)
        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        qat_model = quantizer.prepare(m)
        qat_model.apply(disable_8da4w_fake_quant)
        self.assertFalse(qat_model.linear1._fake_quant_enabled)
        self.assertFalse(qat_model.linear2._fake_quant_enabled)
        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)

        # With fake quant disabled, the QAT linear is just a normal linear
        m2.linear1.weight = qat_model.linear1.weight
        m2.linear2.weight = qat_model.linear2.weight
        m2.sub.linear.weight = qat_model.sub.linear.weight
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
        qat_out = qat_model(*x)
        nn_out = m2(*x2)
        torch.testing.assert_close(nn_out, qat_out, atol=0, rtol=0)

        # Re-enable fake quant
        qat_model.apply(enable_8da4w_fake_quant)
        self.assertTrue(qat_model.linear1._fake_quant_enabled)
        self.assertTrue(qat_model.linear2._fake_quant_enabled)
        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)

        # Fake quant should be applied as normal
        quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        qat_model2 = quantizer2.prepare(m3)
        qat_model2.linear1.weight = qat_model.linear1.weight
        qat_model2.linear2.weight = qat_model.linear2.weight
        qat_model2.sub.linear.weight = qat_model.sub.linear.weight
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
        qat_out = qat_model(*x)
        qat_out2 = qat_model2(*x2)
        torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
    def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
        """
        Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
        """
        from torchao.quantization.prototype.qat import (
            Int8DynActInt4WeightQATQuantizer,
            disable_8da4w_fake_quant,
        )

        group_size = 16
        torch.manual_seed(self.SEED)
        m = M()
        nn_model = copy.deepcopy(m)
        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        qat_model = quantizer.prepare(m)
        qat_model.apply(disable_8da4w_fake_quant)
        nn_model.linear1.weight = qat_model.linear1.weight
        nn_model.linear2.weight = qat_model.linear2.weight
        nn_model.sub.linear.weight = qat_model.sub.linear.weight

        # Simulate training for both models
        optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
        optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
        loss_fn1 = torch.nn.CrossEntropyLoss()
        loss_fn2 = torch.nn.CrossEntropyLoss()
        example_inputs = nn_model.example_inputs()
        target = torch.randn(1, 64).float()
        output1 = nn_model(*example_inputs)
        output2 = qat_model(*example_inputs)
        torch.testing.assert_close(output1, output2, atol=0, rtol=0)
        loss1 = loss_fn1(output1, target)
        loss2 = loss_fn2(output2, target)
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        loss1.backward()
        loss2.backward()
        optimizer1.step()
        optimizer2.step()

        # After 1 training step, weights should match exactly
        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)


if __name__ == "__main__":
    unittest.main()
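
The tests above flip fake quant for every layer, including the nested sub.linear, with a single qat_model.apply(...) call. That works because nn.Module.apply recurses into all submodules before calling the function on the module itself; a toy illustration (unrelated to the test's M class):

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(8, 8)
        self.sub = nn.Sequential(nn.Linear(8, 8))  # a nested submodule

def report(mod: nn.Module):
    # apply() invokes this on every submodule (children first), then on the root
    print(type(mod).__name__)

Toy().apply(report)  # prints: Linear, Linear, Sequential, Toy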
67 changes: 47 additions & 20 deletions torchao/quantization/prototype/qat.py
@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Tuple
+from typing import Any, Optional, Tuple

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
@@ -129,30 +129,43 @@ def __init__(
        self.groupsize = groupsize
        self.precision = precision
        self.scales_precision = scales_precision
+        self._fake_quant_enabled = True
+
+    def enable_fake_quant(self, enabled: bool = True):
+        self._fake_quant_enabled = enabled
+
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # activations: int8 dynamic asymmetric quant
-        (act_qmin, act_qmax) = self._get_qmin_qmax(8)
-        (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
-            x, torch.int8, # dtype not used
-        )
-        x_fq = fake_quantize_per_token(
-            x, act_scales, act_zp, act_qmin, act_qmax,
-        )
+        if self._fake_quant_enabled:
+            (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+                x, torch.int8, # dtype not used
+            )
+            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            x_fq = fake_quantize_per_token(
+                x, act_scales, act_zp, act_qmin, act_qmax,
+            )
+        else:
+            x_fq = x

        # weights: int4 grouped per channel symmetric quant
-        (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
-        (weight_scales, weight_zp) = get_group_qparams_symmetric(
-            self.weight, 4, self.groupsize, self.scales_precision,
-        )
-        w_fq = fake_quantize_per_channel_group(
-            self.weight,
-            weight_scales,
-            weight_zp,
-            weight_qmin,
-            weight_qmax,
-            self.groupsize,
-        )
+        if self._fake_quant_enabled:
+            (weight_scales, weight_zp) = get_group_qparams_symmetric(
+                self.weight, 4, self.groupsize, self.scales_precision,
+            )
+            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            w_fq = fake_quantize_per_channel_group(
+                self.weight,
+                weight_scales,
+                weight_zp,
+                weight_qmin,
+                weight_qmax,
+                self.groupsize,
+            )
+        else:
+            w_fq = self.weight
        return torch.nn.functional.linear(x_fq, w_fq)

    # TODO: move this to common util
@@ -161,6 +174,20 @@ def _get_qmin_qmax(self, n_bit: int):
        qmax = 2 ** (n_bit - 1) - 1
        return (qmin, qmax)

+def enable_8da4w_fake_quant(mod: torch.nn.Module):
+    """
+    Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
+    """
+    if isinstance(mod, Int8DynActInt4WeightQATLinear):
+        mod.enable_fake_quant()
+
+def disable_8da4w_fake_quant(mod: torch.nn.Module):
+    """
+    Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
+    """
+    if isinstance(mod, Int8DynActInt4WeightQATLinear):
+        mod.disable_fake_quant()
+

# ========================
# | QUANT PRIMITIVES |
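
For context, fake quantization as used above is a quantize-dequantize round trip: values are snapped to the integer grid and immediately mapped back to float, so training runs in float while being exposed to quantization error. The sketch below is only a conceptual illustration of the per-token activation case with assumed min/max-based qparams; it is not torchao's actual fake_quantize_per_token / _choose_qparams_per_token_asymmetric implementation, which also handles gradient flow.

import torch

def fake_quantize_per_token_sketch(x: torch.Tensor, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    # Illustrative only: derive per-token (last dim) asymmetric qparams from min/max,
    # then quantize and immediately dequantize. The output stays in float but only
    # takes values representable on the int8 grid.
    min_val = torch.clamp(x.amin(dim=-1, keepdim=True), max=0.0)  # range must include 0
    max_val = torch.clamp(x.amax(dim=-1, keepdim=True), min=0.0)
    scale = torch.clamp((max_val - min_val) / (qmax - qmin), min=1e-9)
    zero_point = torch.round(qmin - min_val / scale)
    x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (x_q - zero_point) * scale

# With fake quant disabled, the QAT linear skips this round trip entirely, which is
# why the new tests can require bit-exact parity with a plain nn.Linear.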
