Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to disable fake quant for 8da4w QAT #198

Merged
merged 4 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,104 @@ def test_qat_8da4w_quantizer(self):
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
"""
from torchao.quantization.prototype.qat import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
enable_8da4w_fake_quant,
)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
m2 = copy.deepcopy(m)
m3 = copy.deepcopy(m)
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
self.assertFalse(qat_model.linear1._fake_quant_enabled)
self.assertFalse(qat_model.linear2._fake_quant_enabled)
self.assertFalse(qat_model.sub.linear._fake_quant_enabled)

# Disabled fake quant is just a normal linear
m2.linear1.weight = qat_model.linear1.weight
m2.linear2.weight = qat_model.linear2.weight
m2.sub.linear.weight = qat_model.sub.linear.weight
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
qat_out = qat_model(*x)
nn_out = m2(*x2)
torch.testing.assert_close(nn_out, qat_out, atol=0, rtol=0)

# Renable fake quant
qat_model.apply(enable_8da4w_fake_quant)
self.assertTrue(qat_model.linear1._fake_quant_enabled)
self.assertTrue(qat_model.linear2._fake_quant_enabled)
self.assertTrue(qat_model.sub.linear._fake_quant_enabled)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model2 = quantizer2.prepare(m3)
qat_model2.linear1.weight = qat_model.linear1.weight
qat_model2.linear2.weight = qat_model.linear2.weight
qat_model2.sub.linear.weight = qat_model.sub.linear.weight
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
qat_out = qat_model(*x)
qat_out2 = qat_model2(*x2)
torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
"""
from torchao.quantization.prototype.qat import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
nn_model = copy.deepcopy(m)
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
nn_model.linear1.weight = qat_model.linear1.weight
nn_model.linear2.weight = qat_model.linear2.weight
nn_model.sub.linear.weight = qat_model.sub.linear.weight

# Simulate training for both models
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn1 = torch.nn.CrossEntropyLoss()
loss_fn2 = torch.nn.CrossEntropyLoss()
example_inputs = nn_model.example_inputs()
target = torch.randn(1, 64).float()
output1 = nn_model(*example_inputs)
output2 = qat_model(*example_inputs)
torch.testing.assert_close(output1, output2, atol=0, rtol=0)
loss1 = loss_fn1(output1, target)
loss2 = loss_fn2(output2, target)
optimizer1.zero_grad()
optimizer2.zero_grad()
loss1.backward()
loss2.backward()
optimizer1.step()
optimizer2.step()

# After 1 training step, weights should match exactly
torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
67 changes: 47 additions & 20 deletions torchao/quantization/prototype/qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Tuple
from typing import Any, Optional, Tuple

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
Expand Down Expand Up @@ -129,30 +129,43 @@ def __init__(
self.groupsize = groupsize
self.precision = precision
self.scales_precision = scales_precision
self._fake_quant_enabled = True

def enable_fake_quant(self, enabled: bool = True):
self._fake_quant_enabled = enabled

def disable_fake_quant(self):
self.enable_fake_quant(False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# activations: int8 dynamic asymmetric quant
(act_qmin, act_qmax) = self._get_qmin_qmax(8)
(act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
x, torch.int8, # dtype not used
)
x_fq = fake_quantize_per_token(
x, act_scales, act_zp, act_qmin, act_qmax,
)
if self._fake_quant_enabled:
(act_scales, act_zp) =_choose_qparams_per_token_asymmetric(
x, torch.int8, # dtype not used
)
(act_qmin, act_qmax) = self._get_qmin_qmax(8)
x_fq = fake_quantize_per_token(
x, act_scales, act_zp, act_qmin, act_qmax,
)
else:
x_fq = x

# weights: int4 grouped per channel symmetric quant
(weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
(weight_scales, weight_zp) = get_group_qparams_symmetric(
self.weight, 4, self.groupsize, self.scales_precision,
)
w_fq = fake_quantize_per_channel_group(
self.weight,
weight_scales,
weight_zp,
weight_qmin,
weight_qmax,
self.groupsize,
)
if self._fake_quant_enabled:
(weight_scales, weight_zp) = get_group_qparams_symmetric(
self.weight, 4, self.groupsize, self.scales_precision,
)
(weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
w_fq = fake_quantize_per_channel_group(
self.weight,
weight_scales,
weight_zp,
weight_qmin,
weight_qmax,
self.groupsize,
)
else:
w_fq = self.weight
return torch.nn.functional.linear(x_fq, w_fq)

# TODO: move this to common util
Expand All @@ -161,6 +174,20 @@ def _get_qmin_qmax(self, n_bit: int):
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)

def enable_8da4w_fake_quant(mod: torch.nn.Module):
"""
Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
"""
if isinstance(mod, Int8DynActInt4WeightQATLinear):
mod.enable_fake_quant()

def disable_8da4w_fake_quant(mod: torch.nn.Module):
"""
Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
"""
if isinstance(mod, Int8DynActInt4WeightQATLinear):
mod.disable_fake_quant()


# ========================
# | QUANT PRIMITIVES |
Expand Down
Loading